[llvm] 2f89d4a - [SCEV] Split collecting and applying rewrite info from loop guards (NFC) (#97316)

via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 2 11:32:32 PDT 2024


Author: Florian Hahn
Date: 2024-07-02T19:32:28+01:00
New Revision: 2f89d4a8c79a2e88f2749c7460886e0d776f3aff

URL: https://github.com/llvm/llvm-project/commit/2f89d4a8c79a2e88f2749c7460886e0d776f3aff
DIFF: https://github.com/llvm/llvm-project/commit/2f89d4a8c79a2e88f2749c7460886e0d776f3aff.diff

LOG: [SCEV] Split collecting and applying rewrite info from loop guards (NFC) (#97316)


Introduce a new LoopGuards class to track information from loop guards, and
move the collection of rewrite info into LoopGuards::collect. This allows
users of applyLoopGuards to collect the rewrite info once in cases where the
same loop guards are applied multiple times.

This is used in howFarToZero to collect the rewrite info only once, which
saves a bit of compile time (measured in instructions:u):
stage1-O3: -0.04%
stage1-ReleaseThinLTO: -0.02%
stage1-ReleaseLTO-g: -0.04%
stage2-O3: -0.02%

https://llvm-compile-time-tracker.com/compare.php?from=117b53ae38428ca66eaa886fb432e6f09db88fe4&to=4ffb7b2e1c99081ccebe6f236c48a0be2f64b6ff&stat=instructions:u

Notably, this reduces the compile time of mafft by 0.9% with -O3, 0.11% with
LTO and 0.12% with stage2-O3.

PR: https://github.com/llvm/llvm-project/pull/97316

Added: 
    

Modified: 
    llvm/include/llvm/Analysis/ScalarEvolution.h
    llvm/lib/Analysis/ScalarEvolution.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Analysis/ScalarEvolution.h b/llvm/include/llvm/Analysis/ScalarEvolution.h
index 97b30daf4427a..d9bfca763819f 100644
--- a/llvm/include/llvm/Analysis/ScalarEvolution.h
+++ b/llvm/include/llvm/Analysis/ScalarEvolution.h
@@ -1299,8 +1299,26 @@ class ScalarEvolution {
   /// sharpen it.
   void setNoWrapFlags(SCEVAddRecExpr *AddRec, SCEV::NoWrapFlags Flags);
 
+  class LoopGuards {
+    DenseMap<const SCEV *, const SCEV *> RewriteMap;
+    bool PreserveNUW = false;
+    bool PreserveNSW = false;
+    ScalarEvolution &SE;
+
+    LoopGuards(ScalarEvolution &SE) : SE(SE) {}
+
+  public:
+    /// Collect rewrite map for loop guards for loop \p L, together with flags
+    /// indicating if NUW and NSW can be preserved during rewriting.
+    static LoopGuards collect(const Loop *L, ScalarEvolution &SE);
+
+    /// Try to apply the collected loop guards to \p Expr.
+    const SCEV *rewrite(const SCEV *Expr) const;
+  };
+
   /// Try to apply information from loop guards for \p L to \p Expr.
   const SCEV *applyLoopGuards(const SCEV *Expr, const Loop *L);
+  const SCEV *applyLoopGuards(const SCEV *Expr, const LoopGuards &Guards);
 
   /// Return true if the loop has no abnormal exits. That is, if the loop
   /// is not infinite, it must exit through an explicit edge in the CFG.

diff  --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index e998fe9452ad7..430e1c6d8f8c6 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -10490,8 +10490,9 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   if (!isLoopInvariant(Step, L))
     return getCouldNotCompute();
 
+  LoopGuards Guards = LoopGuards::collect(L, *this);
   // Specialize step for this loop so we get context sensitive facts below.
-  const SCEV *StepWLG = applyLoopGuards(Step, L);
+  const SCEV *StepWLG = applyLoopGuards(Step, Guards);
 
   // For positive steps (counting up until unsigned overflow):
   //   N = -Start/Step (as unsigned)
@@ -10508,7 +10509,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
   //   N = Distance (as unsigned)
   if (StepC &&
       (StepC->getValue()->isOne() || StepC->getValue()->isMinusOne())) {
-    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, L));
+    APInt MaxBECount = getUnsignedRangeMax(applyLoopGuards(Distance, Guards));
     MaxBECount = APIntOps::umin(MaxBECount, getUnsignedRangeMax(Distance));
 
     // When a loop like "for (int i = 0; i != n; ++i) { /* body */ }" is rotated,
@@ -10549,7 +10550,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
         getUDivExpr(Distance, CountDown ? getNegativeSCEV(Step) : Step);
     const SCEV *ConstantMax = getCouldNotCompute();
     if (Exact != getCouldNotCompute()) {
-      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, L));
+      APInt MaxInt = getUnsignedRangeMax(applyLoopGuards(Exact, Guards));
       ConstantMax =
           getConstant(APIntOps::umin(MaxInt, getUnsignedRangeMax(Exact)));
     }
@@ -10566,7 +10567,7 @@ ScalarEvolution::ExitLimit ScalarEvolution::howFarToZero(const SCEV *V,
 
   const SCEV *M = E;
   if (E != getCouldNotCompute()) {
-    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, L));
+    APInt MaxWithGuards = getUnsignedRangeMax(applyLoopGuards(E, Guards));
     M = getConstant(APIntOps::umin(MaxWithGuards, getUnsignedRangeMax(E)));
   }
   auto *S = isa<SCEVCouldNotCompute>(E) ? M : E;
@@ -15086,112 +15087,9 @@ bool ScalarEvolution::matchURem(const SCEV *Expr, const SCEV *&LHS,
   return false;
 }
 
-/// A rewriter to replace SCEV expressions in Map with the corresponding entry
-/// in the map. It skips AddRecExpr because we cannot guarantee that the
-/// replacement is loop invariant in the loop of the AddRec.
-class SCEVLoopGuardRewriter : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
-  const DenseMap<const SCEV *, const SCEV *> &Map;
-
-  SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
-
-public:
-  SCEVLoopGuardRewriter(ScalarEvolution &SE,
-                        DenseMap<const SCEV *, const SCEV *> &M,
-                        bool PreserveNUW, bool PreserveNSW)
-      : SCEVRewriteVisitor(SE), Map(M) {
-    if (PreserveNUW)
-      FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
-    if (PreserveNSW)
-      FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
-  }
-
-  const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
-
-  const SCEV *visitUnknown(const SCEVUnknown *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return Expr;
-    return I->second;
-  }
-
-  const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end()) {
-      // If we didn't find the extact ZExt expr in the map, check if there's an
-      // entry for a smaller ZExt we can use instead.
-      Type *Ty = Expr->getType();
-      const SCEV *Op = Expr->getOperand(0);
-      unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
-      while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
-             Bitwidth > Op->getType()->getScalarSizeInBits()) {
-        Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
-        auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
-        auto I = Map.find(NarrowExt);
-        if (I != Map.end())
-          return SE.getZeroExtendExpr(I->second, Ty);
-        Bitwidth = Bitwidth / 2;
-      }
-
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
-          Expr);
-    }
-    return I->second;
-  }
-
-  const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
-          Expr);
-    return I->second;
-  }
-
-  const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
-    return I->second;
-  }
-
-  const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
-    auto I = Map.find(Expr);
-    if (I == Map.end())
-      return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
-    return I->second;
-  }
-
-  const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
-    bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
-      Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
-      Changed |= Op != Operands.back();
-    }
-    // We are only replacing operands with equivalent values, so transfer the
-    // flags from the original expression.
-    return !Changed
-               ? Expr
-               : SE.getAddExpr(Operands, ScalarEvolution::maskFlags(
-                                             Expr->getNoWrapFlags(), FlagMask));
-  }
-
-  const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
-    SmallVector<const SCEV *, 2> Operands;
-    bool Changed = false;
-    for (const auto *Op : Expr->operands()) {
-      Operands.push_back(SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
-      Changed |= Op != Operands.back();
-    }
-    // We are only replacing operands with equivalent values, so transfer the
-    // flags from the original expression.
-    return !Changed
-               ? Expr
-               : SE.getMulExpr(Operands, ScalarEvolution::maskFlags(
-                                             Expr->getNoWrapFlags(), FlagMask));
-  }
-};
-
-const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+ScalarEvolution::LoopGuards
+ScalarEvolution::LoopGuards::collect(const Loop *L, ScalarEvolution &SE) {
+  LoopGuards Guards(SE);
   SmallVector<const SCEV *> ExprsToRewrite;
   auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
                               const SCEV *RHS,
@@ -15211,7 +15109,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     // Check for a condition of the form (-C1 + X < C2).  InstCombine will
     // create this form when combining two checks of the form (X u< C2 + C1) and
     // (X >=u C1).
-    auto MatchRangeCheckIdiom = [this, Predicate, LHS, RHS, &RewriteMap,
+    auto MatchRangeCheckIdiom = [&SE, Predicate, LHS, RHS, &RewriteMap,
                                  &ExprsToRewrite]() {
       auto *AddExpr = dyn_cast<SCEVAddExpr>(LHS);
       if (!AddExpr || AddExpr->getNumOperands() != 2)
@@ -15232,9 +15130,10 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
         return false;
       auto I = RewriteMap.find(LHSUnknown);
       const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : LHSUnknown;
-      RewriteMap[LHSUnknown] = getUMaxExpr(
-          getConstant(ExactRegion.getUnsignedMin()),
-          getUMinExpr(RewrittenLHS, getConstant(ExactRegion.getUnsignedMax())));
+      RewriteMap[LHSUnknown] = SE.getUMaxExpr(
+          SE.getConstant(ExactRegion.getUnsignedMin()),
+          SE.getUMinExpr(RewrittenLHS,
+                         SE.getConstant(ExactRegion.getUnsignedMax())));
       ExprsToRewrite.push_back(LHSUnknown);
       return true;
     };
@@ -15287,7 +15186,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       APInt Rem = ExprVal.urem(DivisorVal);
       if (!Rem.isZero())
         // return the SCEV: Expr + Divisor - Expr % Divisor
-        return getConstant(ExprVal + DivisorVal - Rem);
+        return SE.getConstant(ExprVal + DivisorVal - Rem);
       return Expr;
     };
 
@@ -15302,7 +15201,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
         return Expr;
       APInt Rem = ExprVal.urem(DivisorVal);
       // return the SCEV: Expr - Expr % Divisor
-      return getConstant(ExprVal - Rem);
+      return SE.getConstant(ExprVal - Rem);
     };
 
     // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
@@ -15318,14 +15217,14 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
             return MinMaxExpr;
           auto IsMin =
               isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
-          assert(isKnownNonNegative(MinMaxLHS) &&
+          assert(SE.isKnownNonNegative(MinMaxLHS) &&
                  "Expected non-negative operand!");
           auto *DivisibleExpr =
               IsMin ? GetPreviousSCEVDividesByDivisor(MinMaxLHS, Divisor)
                     : GetNextSCEVDividesByDivisor(MinMaxLHS, Divisor);
           SmallVector<const SCEV *> Ops = {
               ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
-          return getMinMaxExpr(SCTy, Ops);
+          return SE.getMinMaxExpr(SCTy, Ops);
         };
 
     // If we have LHS == 0, check if LHS is computing a property of some unknown
@@ -15337,14 +15236,14 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       // explicitly express that.
       const SCEV *URemLHS = nullptr;
       const SCEV *URemRHS = nullptr;
-      if (matchURem(LHS, URemLHS, URemRHS)) {
+      if (SE.matchURem(LHS, URemLHS, URemRHS)) {
         if (const SCEVUnknown *LHSUnknown = dyn_cast<SCEVUnknown>(URemLHS)) {
           auto I = RewriteMap.find(LHSUnknown);
           const SCEV *RewrittenLHS =
               I != RewriteMap.end() ? I->second : LHSUnknown;
           RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
           const auto *Multiple =
-              getMulExpr(getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
+              SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
           RewriteMap[LHSUnknown] = Multiple;
           ExprsToRewrite.push_back(LHSUnknown);
           return;
@@ -15353,7 +15252,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     }
 
     // Do not apply information for constants or if RHS contains an AddRec.
-    if (isa<SCEVConstant>(LHS) || containsAddRecurrence(RHS))
+    if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
       return;
 
     // If RHS is SCEVUnknown, make sure the information is applied to it.
@@ -15412,7 +15311,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     // Return true if Expr known to divide by \p DividesBy.
     std::function<bool(const SCEV *, const SCEV *&)> IsKnownToDivideBy =
         [&](const SCEV *Expr, const SCEV *DividesBy) {
-          if (getURemExpr(Expr, DividesBy)->isZero())
+          if (SE.getURemExpr(Expr, DividesBy)->isZero())
             return true;
           if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr))
             return IsKnownToDivideBy(MinMax->getOperand(0), DividesBy) &&
@@ -15438,21 +15337,21 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     // We cannot express strict predicates in SCEV, so instead we replace them
     // with non-strict ones against plus or minus one of RHS depending on the
     // predicate.
-    const SCEV *One = getOne(RHS->getType());
+    const SCEV *One = SE.getOne(RHS->getType());
     switch (Predicate) {
       case CmpInst::ICMP_ULT:
         if (RHS->getType()->isPointerTy())
           return;
-        RHS = getUMaxExpr(RHS, One);
+        RHS = SE.getUMaxExpr(RHS, One);
         [[fallthrough]];
       case CmpInst::ICMP_SLT: {
-        RHS = getMinusSCEV(RHS, One);
+        RHS = SE.getMinusSCEV(RHS, One);
         RHS = DividesBy ? GetPreviousSCEVDividesByDivisor(RHS, DividesBy) : RHS;
         break;
       }
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_SGT:
-        RHS = getAddExpr(RHS, One);
+        RHS = SE.getAddExpr(RHS, One);
         RHS = DividesBy ? GetNextSCEVDividesByDivisor(RHS, DividesBy) : RHS;
         break;
       case CmpInst::ICMP_ULE:
@@ -15486,25 +15385,25 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       switch (Predicate) {
       case CmpInst::ICMP_ULT:
       case CmpInst::ICMP_ULE:
-        To = getUMinExpr(FromRewritten, RHS);
+        To = SE.getUMinExpr(FromRewritten, RHS);
         if (auto *UMax = dyn_cast<SCEVUMaxExpr>(FromRewritten))
           EnqueueOperands(UMax);
         break;
       case CmpInst::ICMP_SLT:
       case CmpInst::ICMP_SLE:
-        To = getSMinExpr(FromRewritten, RHS);
+        To = SE.getSMinExpr(FromRewritten, RHS);
         if (auto *SMax = dyn_cast<SCEVSMaxExpr>(FromRewritten))
           EnqueueOperands(SMax);
         break;
       case CmpInst::ICMP_UGT:
       case CmpInst::ICMP_UGE:
-        To = getUMaxExpr(FromRewritten, RHS);
+        To = SE.getUMaxExpr(FromRewritten, RHS);
         if (auto *UMin = dyn_cast<SCEVUMinExpr>(FromRewritten))
           EnqueueOperands(UMin);
         break;
       case CmpInst::ICMP_SGT:
       case CmpInst::ICMP_SGE:
-        To = getSMaxExpr(FromRewritten, RHS);
+        To = SE.getSMaxExpr(FromRewritten, RHS);
         if (auto *SMin = dyn_cast<SCEVSMinExpr>(FromRewritten))
           EnqueueOperands(SMin);
         break;
@@ -15517,7 +15416,7 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
             cast<SCEVConstant>(RHS)->getValue()->isNullValue()) {
           const SCEV *OneAlignedUp =
               DividesBy ? GetNextSCEVDividesByDivisor(One, DividesBy) : One;
-          To = getUMaxExpr(FromRewritten, OneAlignedUp);
+          To = SE.getUMaxExpr(FromRewritten, OneAlignedUp);
         }
         break;
       default:
@@ -15532,22 +15431,23 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
   BasicBlock *Header = L->getHeader();
   SmallVector<PointerIntPair<Value *, 1, bool>> Terms;
   // First, collect information from assumptions dominating the loop.
-  for (auto &AssumeVH : AC.assumptions()) {
+  for (auto &AssumeVH : SE.AC.assumptions()) {
     if (!AssumeVH)
       continue;
     auto *AssumeI = cast<CallInst>(AssumeVH);
-    if (!DT.dominates(AssumeI, Header))
+    if (!SE.DT.dominates(AssumeI, Header))
       continue;
     Terms.emplace_back(AssumeI->getOperand(0), true);
   }
 
   // Second, collect information from llvm.experimental.guards dominating the loop.
-  auto *GuardDecl = F.getParent()->getFunction(
+  auto *GuardDecl = SE.F.getParent()->getFunction(
       Intrinsic::getName(Intrinsic::experimental_guard));
   if (GuardDecl)
     for (const auto *GU : GuardDecl->users())
       if (const auto *Guard = dyn_cast<IntrinsicInst>(GU))
-        if (Guard->getFunction() == Header->getParent() && DT.dominates(Guard, Header))
+        if (Guard->getFunction() == Header->getParent() &&
+            SE.DT.dominates(Guard, Header))
           Terms.emplace_back(Guard->getArgOperand(0), true);
 
   // Third, collect conditions from dominating branches. Starting at the loop
@@ -15557,7 +15457,8 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
   // TODO: share this logic with isLoopEntryGuardedByCond.
   for (std::pair<const BasicBlock *, const BasicBlock *> Pair(
            L->getLoopPredecessor(), Header);
-       Pair.first; Pair = getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
+       Pair.first;
+       Pair = SE.getPredecessorWithUniqueSuccessorForBB(Pair.first)) {
 
     const BranchInst *LoopEntryPredicate =
         dyn_cast<BranchInst>(Pair.first->getTerminator());
@@ -15568,11 +15469,10 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
                        LoopEntryPredicate->getSuccessor(0) == Pair.second);
   }
 
-  // Now apply the information from the collected conditions to RewriteMap.
-  // Conditions are processed in reverse order, so the earliest conditions is
-  // processed first. This ensures the SCEVs with the shortest dependency chains
-  // are constructed first.
-  DenseMap<const SCEV *, const SCEV *> RewriteMap;
+  // Now apply the information from the collected conditions to
+  // Guards.RewriteMap. Conditions are processed in reverse order, so the
+  // earliest conditions is processed first. This ensures the SCEVs with the
+  // shortest dependency chains are constructed first.
   for (auto [Term, EnterIfTrue] : reverse(Terms)) {
     SmallVector<Value *, 8> Worklist;
     SmallPtrSet<Value *, 8> Visited;
@@ -15585,9 +15485,9 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
       if (auto *Cmp = dyn_cast<ICmpInst>(Cond)) {
         auto Predicate =
             EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
-        const auto *LHS = getSCEV(Cmp->getOperand(0));
-        const auto *RHS = getSCEV(Cmp->getOperand(1));
-        CollectCondition(Predicate, LHS, RHS, RewriteMap);
+        const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
+        const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
+        CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
         continue;
       }
 
@@ -15600,18 +15500,17 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
     }
   }
 
-  if (RewriteMap.empty())
-    return Expr;
-
   // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
   // the replacement expressions are contained in the ranges of the replaced
   // expressions.
-  bool PreserveNUW = true;
-  bool PreserveNSW = true;
+  Guards.PreserveNUW = true;
+  Guards.PreserveNSW = true;
   for (const SCEV *Expr : ExprsToRewrite) {
-    const SCEV *RewriteTo = RewriteMap[Expr];
-    PreserveNUW &= getUnsignedRange(Expr).contains(getUnsignedRange(RewriteTo));
-    PreserveNSW &= getSignedRange(Expr).contains(getSignedRange(RewriteTo));
+    const SCEV *RewriteTo = Guards.RewriteMap[Expr];
+    Guards.PreserveNUW &=
+        SE.getUnsignedRange(Expr).contains(SE.getUnsignedRange(RewriteTo));
+    Guards.PreserveNSW &=
+        SE.getSignedRange(Expr).contains(SE.getSignedRange(RewriteTo));
   }
 
   // Now that all rewrite information is collect, rewrite the collected
@@ -15619,13 +15518,134 @@ const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
   // sub-expressions.
   if (ExprsToRewrite.size() > 1) {
     for (const SCEV *Expr : ExprsToRewrite) {
-      const SCEV *RewriteTo = RewriteMap[Expr];
-      RewriteMap.erase(Expr);
-      SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW,
-                                     PreserveNSW);
-      RewriteMap.insert({Expr, Rewriter.visit(RewriteTo)});
+      const SCEV *RewriteTo = Guards.RewriteMap[Expr];
+      Guards.RewriteMap.erase(Expr);
+      Guards.RewriteMap.insert({Expr, Guards.rewrite(RewriteTo)});
     }
   }
-  SCEVLoopGuardRewriter Rewriter(*this, RewriteMap, PreserveNUW, PreserveNSW);
+  return Guards;
+}
+
+const SCEV *ScalarEvolution::LoopGuards::rewrite(const SCEV *Expr) const {
+  /// A rewriter to replace SCEV expressions in Map with the corresponding entry
+  /// in the map. It skips AddRecExpr because we cannot guarantee that the
+  /// replacement is loop invariant in the loop of the AddRec.
+  class SCEVLoopGuardRewriter
+      : public SCEVRewriteVisitor<SCEVLoopGuardRewriter> {
+    const DenseMap<const SCEV *, const SCEV *> &Map;
+
+    SCEV::NoWrapFlags FlagMask = SCEV::FlagAnyWrap;
+
+  public:
+    SCEVLoopGuardRewriter(ScalarEvolution &SE,
+                          const ScalarEvolution::LoopGuards &Guards)
+        : SCEVRewriteVisitor(SE), Map(Guards.RewriteMap) {
+      if (Guards.PreserveNUW)
+        FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNUW);
+      if (Guards.PreserveNSW)
+        FlagMask = ScalarEvolution::setFlags(FlagMask, SCEV::FlagNSW);
+    }
+
+    const SCEV *visitAddRecExpr(const SCEVAddRecExpr *Expr) { return Expr; }
+
+    const SCEV *visitUnknown(const SCEVUnknown *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return Expr;
+      return I->second;
+    }
+
+    const SCEV *visitZeroExtendExpr(const SCEVZeroExtendExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end()) {
+        // If we didn't find the extact ZExt expr in the map, check if there's
+        // an entry for a smaller ZExt we can use instead.
+        Type *Ty = Expr->getType();
+        const SCEV *Op = Expr->getOperand(0);
+        unsigned Bitwidth = Ty->getScalarSizeInBits() / 2;
+        while (Bitwidth % 8 == 0 && Bitwidth >= 8 &&
+               Bitwidth > Op->getType()->getScalarSizeInBits()) {
+          Type *NarrowTy = IntegerType::get(SE.getContext(), Bitwidth);
+          auto *NarrowExt = SE.getZeroExtendExpr(Op, NarrowTy);
+          auto I = Map.find(NarrowExt);
+          if (I != Map.end())
+            return SE.getZeroExtendExpr(I->second, Ty);
+          Bitwidth = Bitwidth / 2;
+        }
+
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitZeroExtendExpr(
+            Expr);
+      }
+      return I->second;
+    }
+
+    const SCEV *visitSignExtendExpr(const SCEVSignExtendExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSignExtendExpr(
+            Expr);
+      return I->second;
+    }
+
+    const SCEV *visitUMinExpr(const SCEVUMinExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitUMinExpr(Expr);
+      return I->second;
+    }
+
+    const SCEV *visitSMinExpr(const SCEVSMinExpr *Expr) {
+      auto I = Map.find(Expr);
+      if (I == Map.end())
+        return SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visitSMinExpr(Expr);
+      return I->second;
+    }
+
+    const SCEV *visitAddExpr(const SCEVAddExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      bool Changed = false;
+      for (const auto *Op : Expr->operands()) {
+        Operands.push_back(
+            SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+        Changed |= Op != Operands.back();
+      }
+      // We are only replacing operands with equivalent values, so transfer the
+      // flags from the original expression.
+      return !Changed ? Expr
+                      : SE.getAddExpr(Operands,
+                                      ScalarEvolution::maskFlags(
+                                          Expr->getNoWrapFlags(), FlagMask));
+    }
+
+    const SCEV *visitMulExpr(const SCEVMulExpr *Expr) {
+      SmallVector<const SCEV *, 2> Operands;
+      bool Changed = false;
+      for (const auto *Op : Expr->operands()) {
+        Operands.push_back(
+            SCEVRewriteVisitor<SCEVLoopGuardRewriter>::visit(Op));
+        Changed |= Op != Operands.back();
+      }
+      // We are only replacing operands with equivalent values, so transfer the
+      // flags from the original expression.
+      return !Changed ? Expr
+                      : SE.getMulExpr(Operands,
+                                      ScalarEvolution::maskFlags(
+                                          Expr->getNoWrapFlags(), FlagMask));
+    }
+  };
+
+  if (RewriteMap.empty())
+    return Expr;
+
+  SCEVLoopGuardRewriter Rewriter(SE, *this);
   return Rewriter.visit(Expr);
 }
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr, const Loop *L) {
+  return applyLoopGuards(Expr, LoopGuards::collect(L, *this));
+}
+
+const SCEV *ScalarEvolution::applyLoopGuards(const SCEV *Expr,
+                                             const LoopGuards &Guards) {
+  return Guards.rewrite(Expr);
+}


        


More information about the llvm-commits mailing list