[llvm] [SCEV] Split collecting and applying rewrite info from loop guards (NFC) (PR #97316)

Florian Hahn via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 2 06:45:34 PDT 2024


https://github.com/fhahn updated https://github.com/llvm/llvm-project/pull/97316

>From bc6780b35d9fe07b1bec5ce760ef6194389830ac Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Fri, 28 Jun 2024 14:41:10 +0100
Subject: [PATCH 1/2] [SCEV] Split collecting and applying rewrite info from
 loop guards (NFC)

Split off collecting rewrite info from loop guards to
collectRewriteInfoFromLoopGuards. This allows users of applyLoopGuards to
collect rewrite info once in cases where the same loop guards are
applied multiple times.

This is used to collect rewrite info once in howFarToZero, which saves a
bit of compile-time:
stage1-O3: -0.04%
stage1-ReleaseThinLTO: -0.02%
stage1-ReleaseLTO-g: -0.04%
stage2-O3: -0.02%
https://llvm-compile-time-tracker.com/compare.php?from=117b53ae38428ca66eaa886fb432e6f09db88fe4&to=4ffb7b2e1c99081ccebe6f236c48a0be2f64b6ff&stat=instructions:u

Notably this improves mafft by -0.9% with -O3, -0.11% with LTO and
-0.12% with stage2-O3.
---
 llvm/include/llvm/Analysis/ScalarEvolution.h |  9 +++++
 llvm/lib/Analysis/ScalarEvolution.cpp        | 38 +++++++++++++++-----
 2 files changed, 38 insertions(+), 9 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 97b30daf4427a..d9173e8745688 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1299,8 +1299,17 @@ class ScalarEvolution {
   /// sharpen it.
   void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
 
+  /// Collect rewrite map for loop guards for loop \p L, together with flags
+  /// indidcating if NUW and NSW can be preserved during rewriting.
+  std::tuple<DenseMap<const SCEV *, const SCEV *>, bool, bool>
+  collectRewriteInfoFromLoopGuards(const Loop *L);
+
   /// Try to apply information from loop guards for \p L to \p Expr.
   const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
+  const SCEV *
+  applyLoopGuards(const SCEV *Expr, const Loop *L,
+                  const DenseMap<const SCEV *, const SCEV *> &RewriteMap,
+                  bool PreserveNUW, bool PreserveNSW);
 
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e998fe9452ad7..9e39ceb6526e1 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10490,8 +10490,11 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
+  const auto &[RewriteMap, PreserveNUW, PreserveNSW] =
+      collectRewriteInfoFromLoopGuards(L);
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG = applyLoopGuards(Step, L);
+  const SCEV *StepWLG =
+      applyLoopGuards(Step, L, RewriteMap, PreserveNUW, PreserveNSW);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10508,7 +10511,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
+    APInt MaxBECount = getUnsignedRangeMax(
+        applyLoopGuards(Distance, L, RewriteMap, PreserveNUW, PreserveNSW));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10549,7 +10553,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
+      APInt MaxInt = getUnsignedRangeMax(
+          applyLoopGuards(Exact, L, RewriteMap, PreserveNUW, PreserveNSW));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10566,7 +10571,8 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
+    APInt MaxWithGuards = getUnsignedRangeMax(
+        applyLoopGuards(E, L, RewriteMap, PreserveNUW, PreserveNSW));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -15096,7 +15102,7 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
 
 public:
   SCEVLoopGuardRewriter(ScalarEvolution &SE,
-                        DenseMap<const SCEV *, const SCEV *> &M,
+                        const DenseMap<const SCEV *, const SCEV *> &M,
                         bool PreserveNUW, bool PreserveNSW)
       : SCEVRewriteVisitor(SE), Map(M) {
     if (PreserveNUW)
@@ -15191,7 +15197,8 @@ class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
   }
 };
 
-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+std::tuple<DenseMap<const SCEV *, const SCEV *>, bool, bool>
+ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
   SmallVector<const SCEV *> ExprsToRewrite;
   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                               const SCEV *RHS,
@@ -15600,9 +15607,6 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     }
   }
 
-  if (RewriteMap.empty())
-    return Expr;
-
   // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
   // the replacement expressions are contained in the ranges of the replaced
   // expressions.
@@ -15626,6 +15630,22 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
     }
   }
+  return {RewriteMap, PreserveNUW, PreserveNSW};
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+  const auto &[RewriteMap, PreserveNUW, PreserveNSW] =
+      collectRewriteInfoFromLoopGuards(L);
+  return applyLoopGuards(Expr, L, RewriteMap, PreserveNUW, PreserveNSW);
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(
+    const SCEV *Expr, const Loop *L,
+    const DenseMap<const SCEV *, const SCEV *> &RewriteMap, bool PreserveNUW,
+    bool PreserveNSW) {
+  if (RewriteMap.empty())
+    return Expr;
+
   SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
   return Rewriter.visit(Expr);
 }

>From 78f6143871a7ffab137f1d80391c33b85ad4832a Mon Sep 17 00:00:00 2001
From: Florian Hahn <flo at fhahn.com>
Date: Tue, 2 Jul 2024 14:33:36 +0100
Subject: [PATCH 2/2] !fixup add LoopGuards class, move collect logic there.

---
 llvm/include/llvm/Analysis/ScalarEvolution.h |  25 +-
 llvm/lib/Analysis/ScalarEvolution.cpp        | 344 +++++++++----------
 2 files changed, 189 insertions(+), 180 deletions(-)

diff --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index d9173e8745688..b36f2beb03137 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1299,17 +1299,26 @@ class ScalarEvolution {
   /// sharpen it.
   void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
 
-  /// Collect rewrite map for loop guards for loop \p L, together with flags
-  /// indidcating if NUW and NSW can be preserved during rewriting.
-  std::tuple<DenseMap<const SCEV *, const SCEV *>, bool, bool>
-  collectRewriteInfoFromLoopGuards(const Loop *L);
+  class LoopGuards {
+    DenseMap<const SCEV *, const SCEV *> RewriteMap;
+    bool PreserveNUW = false;
+    bool PreserveNSW = false;
+    ScalarEvolution &SE;
+
+    LoopGuards(ScalarEvolution &SE) : SE(SE) {}
+
+  public:
+    /// Collect rewrite map for loop guards for loop \p L, together with flags
+    /// indicating if NUW and NSW can be preserved during rewriting.
+    static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
+
+    /// Try to apply the collected loop guards to \p Expr.
+    const SCEV *rewrite(const SCEV *Expr) const;
+  };
 
   /// Try to apply information from loop guards for \p L to \p Expr.
   const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
-  const SCEV *
-  applyLoopGuards(const SCEV *Expr, const Loop *L,
-                  const DenseMap<const SCEV *, const SCEV *> &RewriteMap,
-                  bool PreserveNUW, bool PreserveNSW);
+  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Gards);
 
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 9e39ceb6526e1..430e1c6d8f8c6 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10490,11 +10490,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
-  const auto &[RewriteMap, PreserveNUW, PreserveNSW] =
-      collectRewriteInfoFromLoopGuards(L);
+  LoopGuards Guards = LoopGuards::collect(L, *this);
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG =
-      applyLoopGuards(Step, L, RewriteMap, PreserveNUW, PreserveNSW);
+  const SCEV *StepWLG = applyLoopGuards(Step, Guards);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10511,8 +10509,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(
-        applyLoopGuards(Distance, L, RewriteMap, PreserveNUW, PreserveNSW));
+    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10553,8 +10550,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(
-          applyLoopGuards(Exact, L, RewriteMap, PreserveNUW, PreserveNSW));
+      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10571,8 +10567,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(
-        applyLoopGuards(E, L, RewriteMap, PreserveNUW, PreserveNSW));
+    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -15092,113 +15087,9 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
   return false;
 }
 
-/// A rewriter to replace SCEV expressions in Map with the corresponding entry
-/// in the map. It skips AddRecExpr because we cannot guarantee that the
-/// replacement is loop invariant in the loop of the AddRec.
-class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
-  const DenseMap<const SCEV *, const SCEV *> &Map;
-
-  SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
-
-public:
-  SCEVLoopGuardRewriter(ScalarEvolution &SE,
-                        const DenseMap<const SCEV *, const SCEV *> &M,
-                        bool PreserveNUW, bool PreserveNSW)
-      : SCEVRewriteVisitor(SE), Map(M) {
-    if (PreserveNUW)
-      FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
-    if (PreserveNSW)
-      FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
-  }
-
-  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
-
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return Expr;
-    return I->second;
-  }
-
-  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end()) {
-      // If we didn't find the extact ZExt expr in the map, check if there's an
-      // entry for a smaller ZExt we can use instead.
-      Type *Ty = Expr->getType();
-      const SCEV *Op = Expr->getOperand(0);
-      unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
-      while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
-             Bitwidth > Op->getType()->getScalarSizeInBits()) {
-        Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
-        auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
-        auto I = Map.find(NarrowExt);
-        if (I != Map.end())
-          return SE.getZeroExtendExpr(I->second, Ty);
-        Bitwidth = Bitwidth / 2;
-      }
-
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
-          Expr);
-    }
-    return I->second;
-  }
-
-  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
-          Expr);
-    return I->second;
-  }
-
-  const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
-    return I->second;
-  }
-
-  const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
-    return I->second;
-  }
-
-  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
-    bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
-      Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
-      Changed |= Op != Operands.back();
-    }
-    // We are only replacing operands with equivalent values, so transfer the
-    // flags from the original expression.
-    return !Changed
-               ? Expr
-               : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
-                                             Expr->getNoWrapFlags(), FlagMask));
-  }
-
-  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
-    bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
-      Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
-      Changed |= Op != Operands.back();
-    }
-    // We are only replacing operands with equivalent values, so transfer the
-    // flags from the original expression.
-    return !Changed
-               ? Expr
-               : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
-                                             Expr->getNoWrapFlags(), FlagMask));
-  }
-};
-
-std::tuple<DenseMap<const SCEV *, const SCEV *>, bool, bool>
-ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
+ScalarEvolution::LoopGuards
+ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
+  LoopGuards Guards(SE);
   SmallVector<const SCEV *> ExprsToRewrite;
   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                               const SCEV *RHS,
@@ -15218,7 +15109,7 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
     // Check for a condition of the form (-C1 + X < C2).  InstCombine will
     // create this form when combining two checks of the form (X u< C2 + C1) and
     // (X >=u C1).
-    auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
+    auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
                                  &ExprsToRewrite]() {
       auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
       if (!AddExpr || AddExpr->getNumOperands() != 2)
@@ -15239,9 +15130,10 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
         return false;
       auto I = RewriteMap.find(LHSUnknown);
       const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
-      RewriteMap[LHSUnknown] = getUMaxExpr(
-          getConstant(ExactRegion.getUnsignedMin()),
-          getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
+      RewriteMap[LHSUnknown] = SE.getUMaxExpr(
+          SE.getConstant(ExactRegion.getUnsignedMin()),
+          SE.getUMinExpr(RewrittenLHS,
+                         SE.getConstant(ExactRegion.getUnsignedMax())));
       ExprsToRewrite.push_back(LHSUnknown);
       return true;
     };
@@ -15294,7 +15186,7 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
       APInt Rem = ExprVal.urem(DivisorVal);
       if (!Rem.isZero())
         // return the SCEV: Expr + Divisor - Expr % Divisor
-        return getConstant(ExprVal + DivisorVal - Rem);
+        return SE.getConstant(ExprVal + DivisorVal - Rem);
       return Expr;
     };
 
@@ -15309,7 +15201,7 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
         return Expr;
       APInt Rem = ExprVal.urem(DivisorVal);
       // return the SCEV: Expr - Expr % Divisor
-      return getConstant(ExprVal - Rem);
+      return SE.getConstant(ExprVal - Rem);
     };
 
     // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
@@ -15325,14 +15217,14 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
             return MinMaxExpr;
           auto IsMin =
               isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
-          assert(isKnownNonNegative(MinMaxLHS) &&
+          assert(SE.isKnownNonNegative(MinMaxLHS) &&
                  "Expected non-negative operand!");
           auto *DivisibleExpr =
               IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
                     : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
           SmallVector<const SCEV *> Ops = {
               ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
-          return getMinMaxExpr(SCTy, Ops);
+          return SE.getMinMaxExpr(SCTy, Ops);
         };
 
     // If we have LHS == 0, check if LHS is computing a property of some unknown
@@ -15344,14 +15236,14 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
       // explicitly express that.
       const SCEV *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
-      if (matchURem(LHS, URemLHS, URemRHS)) {
+      if (SE.matchURem(LHS, URemLHS, URemRHS)) {
         if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
           auto I = RewriteMap.find(LHSUnknown);
           const SCEV *RewrittenLHS =
               I != RewriteMap.end() ? I->second : LHSUnknown;
           RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
           const auto *Multiple =
-              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+              SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
           RewriteMap[LHSUnknown] = Multiple;
           ExprsToRewrite.push_back(LHSUnknown);
           return;
@@ -15360,7 +15252,7 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
     }
 
     // Do not apply information for constants or if RHS contains an AddRec.
-    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+    if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
       return;
 
     // If RHS is SCEVUnknown, make sure the information is applied to it.
@@ -15419,7 +15311,7 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
     // Return true if Expr known to divide by \p DividesBy.
     std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
         [&](const SCEV *Expr, const SCEV *DividesBy) {
-          if (getURemExpr(Expr, DividesBy)->isZero())
+          if (SE.getURemExpr(Expr, DividesBy)->isZero())
             return true;
           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
             return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
@@ -15445,21 +15337,21 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
     // We cannot express strict predicates in SCEV, so instead we replace them
     // with non-strict ones against plus or minus one of RHS depending on the
     // predicate.
-    const SCEV *One = getOne(RHS->getType());
+    const SCEV *One = SE.getOne(RHS->getType());
     switch (Predicate) {
       case CmpInst::ICMP_ULT:
         if (RHS->getType()->isPointerTy())
           return;
-        RHS = getUMaxExpr(RHS, One);
+        RHS = SE.getUMaxExpr(RHS, One);
         [[fallthrough]];
       case CmpInst::ICMP_SLT: {
-        RHS = getMinusSCEV(RHS, One);
+        RHS = SE.getMinusSCEV(RHS, One);
         RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
         break;
       }
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_SGT:
-        RHS = getAddExpr(RHS, One);
+        RHS = SE.getAddExpr(RHS, One);
         RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
         break;
       case CmpInst::ICMP_ULE:
@@ -15493,25 +15385,25 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
       switch (Predicate) {
       case CmpInst::ICMP_ULT:
       case CmpInst::ICMP_ULE:
-        To = getUMinExpr(FromRewritten, RHS);
+        To = SE.getUMinExpr(FromRewritten, RHS);
         if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
           EnqueueOperands(UMax);
         break;
       case CmpInst::ICMP_SLT:
       case CmpInst::ICMP_SLE:
-        To = getSMinExpr(FromRewritten, RHS);
+        To = SE.getSMinExpr(FromRewritten, RHS);
         if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
           EnqueueOperands(SMax);
         break;
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_UGE:
-        To = getUMaxExpr(FromRewritten, RHS);
+        To = SE.getUMaxExpr(FromRewritten, RHS);
         if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
           EnqueueOperands(UMin);
         break;
       case CmpInst::ICMP_SGT:
       case CmpInst::ICMP_SGE:
-        To = getSMaxExpr(FromRewritten, RHS);
+        To = SE.getSMaxExpr(FromRewritten, RHS);
         if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
           EnqueueOperands(SMin);
         break;
@@ -15524,7 +15416,7 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
             cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
           const SCEV *OneAlignedUp =
               DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
-          To = getUMaxExpr(FromRewritten, OneAlignedUp);
+          To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
         }
         break;
       default:
@@ -15539,22 +15431,23 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
   BasicBlock *Header = L->getHeader();
   SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
   // First, collect information from assumptions dominating the loop.
-  for (auto &AssumeVH : AC.assumptions()) {
+  for (auto &AssumeVH : SE.AC.assumptions()) {
     if (!AssumeVH)
       continue;
     auto *AssumeI = cast<CallInst>(AssumeVH);
-    if (!DT.dominates(AssumeI, Header))
+    if (!SE.DT.dominates(AssumeI, Header))
       continue;
     Terms.emplace_back(AssumeI->getOperand(0), true);
   }
 
   // Second, collect information from llvm.experimental.guards dominating the loop.
-  auto *GuardDecl = F.getParent()->getFunction(
+  auto *GuardDecl = SE.F.getParent()->getFunction(
       Intrinsic::getName(Intrinsic::experimental_guard));
   if (GuardDecl)
     for (const auto *GU : GuardDecl->users())
       if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
-        if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
+        if (Guard->getFunction() == Header->getParent() &&
+            SE.DT.dominates(Guard, Header))
           Terms.emplace_back(Guard->getArgOperand(0), true);
 
   // Third, collect conditions from dominating branches. Starting at the loop
@@ -15564,7 +15457,8 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
   // TODO: share this logic with isLoopEntryGuardedByCond.
   for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
            L->getLoopPredecessor(), Header);
-       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+       Pair.first;
+       Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
 
     const BranchInst *LoopEntryPredicate =
         dyn_cast<BranchInst>(Pair.first->getTerminator());
@@ -15575,11 +15469,10 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
                        LoopEntryPredicate->getSuccessor(0) == Pair.second);
   }
 
-  // Now apply the information from the collected conditions to RewriteMap.
-  // Conditions are processed in reverse order, so the earliest conditions is
-  // processed first. This ensures the SCEVs with the shortest dependency chains
-  // are constructed first.
-  DenseMap<const SCEV *, const SCEV *> RewriteMap;
+  // Now apply the information from the collected conditions to
+  // Guards.RewriteMap. Conditions are processed in reverse order, so the
+  // earliest condition is processed first. This ensures the SCEVs with the
+  // shortest dependency chains are constructed first.
   for (auto [Term, EnterIfTrue] : reverse(Terms)) {
     SmallVector<Value *, 8> Worklist;
     SmallPtrSet<Value *, 8> Visited;
@@ -15592,9 +15485,9 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
       if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
         auto Predicate =
             EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
-        const auto *LHS = getSCEV(Cmp->getOperand(0));
-        const auto *RHS = getSCEV(Cmp->getOperand(1));
-        CollectCondition(Predicate, LHS, RHS, RewriteMap);
+        const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
+        const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
+        CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
         continue;
       }
 
@@ -15610,12 +15503,14 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
   // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
   // the replacement expressions are contained in the ranges of the replaced
   // expressions.
-  bool PreserveNUW = true;
-  bool PreserveNSW = true;
+  Guards.PreserveNUW = true;
+  Guards.PreserveNSW = true;
   for (const SCEV *Expr : ExprsToRewrite) {
-    const SCEV *RewriteTo = RewriteMap[Expr];
-    PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
-    PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
+    const SCEV *RewriteTo = Guards.RewriteMap[Expr];
+    Guards.PreserveNUW &=
+        SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
+    Guards.PreserveNSW &=
+        SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
   }
 
   // Now that all rewrite information is collect, rewrite the collected
@@ -15623,29 +15518,134 @@ ScalarEvolution::collectRewriteInfoFromLoopGuards(const Loop *L) {
   // sub-expressions.
   if (ExprsToRewrite.size() > 1) {
     for (const SCEV *Expr : ExprsToRewrite) {
-      const SCEV *RewriteTo = RewriteMap[Expr];
-      RewriteMap.erase(Expr);
-      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
-                                     PreserveNSW);
-      RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
+      const SCEV *RewriteTo = Guards.RewriteMap[Expr];
+      Guards.RewriteMap.erase(Expr);
+      Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
     }
   }
-  return {RewriteMap, PreserveNUW, PreserveNSW};
+  return Guards;
 }
 
-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
-  const auto &[RewriteMap, PreserveNUW, PreserveNSW] =
-      collectRewriteInfoFromLoopGuards(L);
-  return applyLoopGuards(Expr, L, RewriteMap, PreserveNUW, PreserveNSW);
-}
+const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
+  /// A rewriter to replace SCEV expressions in Map with the corresponding entry
+  /// in the map. It skips AddRecExpr because we cannot guarantee that the
+  /// replacement is loop invariant in the loop of the AddRec.
+  class SCEVLoopGuardRewriter
+      : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
+    const DenseMap<const SCEV *, const SCEV *> &Map;
+
+    SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
+
+  public:
+    SCEVLoopGuardRewriter(ScalarEvolution &SE,
+                          const ScalarEvolution::LoopGuards &Guards)
+        : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+      if (Guards.PreserveNUW)
+        FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
+      if (Guards.PreserveNSW)
+        FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
+    }
+
+    const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
+
+    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return Expr;
+      return I->second;
+    }
+
+    const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end()) {
+        // If we didn't find the exact ZExt expr in the map, check if there's
+        // an entry for a smaller ZExt we can use instead.
+        Type *Ty = Expr->getType();
+        const SCEV *Op = Expr->getOperand(0);
+        unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
+        while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
+               Bitwidth > Op->getType()->getScalarSizeInBits()) {
+          Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
+          auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
+          auto I = Map.find(NarrowExt);
+          if (I != Map.end())
+            return SE.getZeroExtendExpr(I->second, Ty);
+          Bitwidth = Bitwidth / 2;
+        }
+
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+            Expr);
+      }
+      return I->second;
+    }
+
+    const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
+            Expr);
+      return I->second;
+    }
+
+    const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
+      return I->second;
+    }
+
+    const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
+      return I->second;
+    }
+
+    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      bool Changed = false;
+      for (const auto *Op : Expr->operands()) {
+        Operands.push_back(
+            SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+        Changed |= Op != Operands.back();
+      }
+      // We are only replacing operands with equivalent values, so transfer the
+      // flags from the original expression.
+      return !Changed ? Expr
+                      : SE.getAddExpr(Operands,
+                                      ScalarEvolution::maskFlags(
+                                          Expr->getNoWrapFlags(), FlagMask));
+    }
+
+    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      bool Changed = false;
+      for (const auto *Op : Expr->operands()) {
+        Operands.push_back(
+            SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+        Changed |= Op != Operands.back();
+      }
+      // We are only replacing operands with equivalent values, so transfer the
+      // flags from the original expression.
+      return !Changed ? Expr
+                      : SE.getMulExpr(Operands,
+                                      ScalarEvolution::maskFlags(
+                                          Expr->getNoWrapFlags(), FlagMask));
+    }
+  };
 
-const SCEV *ScalarEvolution::applyLoopGuards(
-    const SCEV *Expr, const Loop *L,
-    const DenseMap<const SCEV *, const SCEV *> &RewriteMap, bool PreserveNUW,
-    bool PreserveNSW) {
   if (RewriteMap.empty())
     return Expr;
 
-  SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
+  SCEVLoopGuardRewriter Rewriter(SE, *this);
   return Rewriter.visit(Expr);
 }
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
+                                             const LoopGuards &Guards) {
+  return Guards.rewrite(Expr);
+}



More information about the llvm-commits mailing list