[llvm] r315713 - [SCEV] Maintain and use a loop->loop invalidation dependency

Sanjoy Das via llvm-commits llvm-commits at lists.llvm.org
Fri Oct 13 10:13:44 PDT 2017


Author: sanjoy
Date: Fri Oct 13 10:13:44 2017
New Revision: 315713

URL: http://llvm.org/viewvc/llvm-project?rev=315713&view=rev
Log:
[SCEV] Maintain and use a loop->loop invalidation dependency

Summary:
This change uses the loop use list added in the previous change to remember
which loops appear in the trip count expressions of other loops, and uses that
information in forgetLoop.  This lets us avoid scanning every loop in the
function on each forgetLoop call.

With this change we no longer clear out backedge taken counts on forgetValue.
I think this is fine -- the contract is that SCEV users must call forgetLoop(L)
if their change to the IR could have changed the trip count of L; calling only
forgetValue on a value feeding into the backedge condition of L is not enough.
Moreover, I don't think we can strengthen forgetValue to be sufficient for
invalidating trip counts without significantly re-architecting SCEV.  For
instance, if we have the loop:

  I = *Ptr;
  E = I + 10;
  do {
    // ...
  } while (++I != E);

then the backedge taken count of the loop is the constant 9, which has no
reference to either I or E; i.e., there is no way in SCEV today to re-discover
the dependency of the loop's trip count on E or I.  So a SCEV client cannot
change E to (say) "I + 20", call forgetValue(E), and expect the loop's trip
count to be updated.

Reviewers: atrick, sunfish, mkazantsev

Subscribers: mcrosier, llvm-commits

Differential Revision: https://reviews.llvm.org/D38435

Modified:
    llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
    llvm/trunk/lib/Analysis/ScalarEvolution.cpp
    llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp

Modified: llvm/trunk/include/llvm/Analysis/ScalarEvolution.h
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/include/llvm/Analysis/ScalarEvolution.h?rev=315713&r1=315712&r2=315713&view=diff
==============================================================================
--- llvm/trunk/include/llvm/Analysis/ScalarEvolution.h (original)
+++ llvm/trunk/include/llvm/Analysis/ScalarEvolution.h Fri Oct 13 10:13:44 2017
@@ -29,6 +29,7 @@
 #include "llvm/ADT/Hashing.h"
 #include "llvm/ADT/Optional.h"
 #include "llvm/ADT/PointerIntPair.h"
+#include "llvm/ADT/PointerUnion.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallPtrSet.h"
 #include "llvm/ADT/SmallVector.h"
@@ -1262,6 +1263,10 @@ private:
 
     /// Invalidate this result and free associated memory.
     void clear();
+
+    /// Insert all loops referred to by this BackedgeTakenCount into \p Result.
+    void findUsedLoops(ScalarEvolution &SE,
+                       SmallPtrSetImpl<const Loop *> &Result) const;
   };
 
   /// Cache the backedge-taken count of the loops for this function as they
@@ -1770,14 +1775,20 @@ private:
   /// Find all of the loops transitively used in \p S, and update \c LoopUsers
   /// accordingly.
   void addToLoopUseLists(const SCEV *S);
+  void addToLoopUseLists(const BackedgeTakenInfo &BTI, const Loop *L);
 
   FoldingSet<SCEV> UniqueSCEVs;
   FoldingSet<SCEVPredicate> UniquePreds;
   BumpPtrAllocator SCEVAllocator;
 
-  /// This maps loops to a list of SCEV expressions that (transitively) use said
-  /// loop.
-  DenseMap<const Loop *, SmallVector<const SCEV *, 4>> LoopUsers;
+  /// This maps loops to a list of entities that (transitively) use said loop.
+  /// A SCEV expression in the vector corresponding to a loop denotes that the
+  /// SCEV expression transitively uses said loop.  A loop (LA) in the vector
+  /// corresponding to another loop (LB) denotes that LB is used in one of the
+  /// cached trip counts for LA.
+  DenseMap<const Loop *,
+           SmallVector<PointerUnion<const SCEV *, const Loop *>, 4>>
+      LoopUsers;
 
   /// Cache tentative mappings from UnknownSCEVs in a Loop, to a SCEV expression
   /// they can be rewritten into under certain predicates.

Modified: llvm/trunk/lib/Analysis/ScalarEvolution.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Analysis/ScalarEvolution.cpp?rev=315713&r1=315712&r2=315713&view=diff
==============================================================================
--- llvm/trunk/lib/Analysis/ScalarEvolution.cpp (original)
+++ llvm/trunk/lib/Analysis/ScalarEvolution.cpp Fri Oct 13 10:13:44 2017
@@ -6293,6 +6293,7 @@ ScalarEvolution::getPredicatedBackedgeTa
   BackedgeTakenInfo Result =
       computeBackedgeTakenCount(L, /*AllowPredicates=*/true);
 
+  addToLoopUseLists(Result, L);
   return PredicatedBackedgeTakenCounts.find(L)->second = std::move(Result);
 }
 
@@ -6368,6 +6369,7 @@ ScalarEvolution::getBackedgeTakenInfo(co
   // recusive call to getBackedgeTakenInfo (on a different
   // loop), which would invalidate the iterator computed
   // earlier.
+  addToLoopUseLists(Result, L);
   return BackedgeTakenCounts.find(L)->second = std::move(Result);
 }
 
@@ -6405,8 +6407,14 @@ void ScalarEvolution::forgetLoop(const L
 
     auto LoopUsersItr = LoopUsers.find(CurrL);
     if (LoopUsersItr != LoopUsers.end()) {
-      for (auto *S : LoopUsersItr->second)
-        forgetMemoizedResults(S);
+      for (auto LoopOrSCEV : LoopUsersItr->second) {
+        if (auto *S = LoopOrSCEV.dyn_cast<const SCEV *>())
+          forgetMemoizedResults(S);
+        else {
+          BackedgeTakenCounts.erase(LoopOrSCEV.get<const Loop *>());
+          PredicatedBackedgeTakenCounts.erase(LoopOrSCEV.get<const Loop *>());
+        }
+      }
       LoopUsers.erase(LoopUsersItr);
     }
 
@@ -6551,6 +6559,34 @@ bool ScalarEvolution::BackedgeTakenInfo:
   return false;
 }
 
+static void findUsedLoopsInSCEVExpr(const SCEV *S,
+                                    SmallPtrSetImpl<const Loop *> &Result) {
+  struct FindUsedLoops { // Visitor for SCEVTraversal.
+    SmallPtrSetImpl<const Loop *> &LoopsUsed;
+    FindUsedLoops(SmallPtrSetImpl<const Loop *> &LoopsUsed)
+        : LoopsUsed(LoopsUsed) {}
+    bool follow(const SCEV *S) { // Record AddRec loops; keep descending.
+      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
+        LoopsUsed.insert(AR->getLoop());
+      return true;
+    }
+    // Never stop early, so every AddRec reachable from S is visited.
+    bool isDone() const { return false; }
+  };
+  FindUsedLoops F(Result); // Result is appended to, never cleared.
+  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+}
+
+void ScalarEvolution::BackedgeTakenInfo::findUsedLoops(
+    ScalarEvolution &SE, SmallPtrSetImpl<const Loop *> &Result) const {
+  if (auto *S = getMax()) // Skip a null or CouldNotCompute max.
+    if (S != SE.getCouldNotCompute())
+      findUsedLoopsInSCEVExpr(S, Result);
+  for (auto &ENT : ExitNotTaken) // Each cached exit's exact count.
+    if (ENT.ExactNotTaken != SE.getCouldNotCompute())
+      findUsedLoopsInSCEVExpr(ENT.ExactNotTaken, Result);
+}
+
 ScalarEvolution::ExitLimit::ExitLimit(const SCEV *E)
     : ExactNotTaken(E), MaxNotTaken(E) {
   assert((isa<SCEVCouldNotCompute>(MaxNotTaken) ||
@@ -11034,21 +11070,6 @@ ScalarEvolution::forgetMemoizedResults(c
       ++I;
   }
 
-  auto RemoveSCEVFromBackedgeMap =
-      [S, this](DenseMap<const Loop *, BackedgeTakenInfo> &Map) {
-        for (auto I = Map.begin(), E = Map.end(); I != E;) {
-          BackedgeTakenInfo &BEInfo = I->second;
-          if (BEInfo.hasOperand(S, this)) {
-            BEInfo.clear();
-            Map.erase(I++);
-          } else
-            ++I;
-        }
-      };
-
-  RemoveSCEVFromBackedgeMap(BackedgeTakenCounts);
-  RemoveSCEVFromBackedgeMap(PredicatedBackedgeTakenCounts);
-
   // TODO: There is a suspicion that we only need to do it when there is a
   // SCEVUnknown somewhere inside S. Need to check this.
   if (EraseExitLimit)
@@ -11058,22 +11079,19 @@ ScalarEvolution::forgetMemoizedResults(c
 }
 
 void ScalarEvolution::addToLoopUseLists(const SCEV *S) {
-  struct FindUsedLoops {
-    SmallPtrSet<const Loop *, 8> LoopsUsed;
-    bool follow(const SCEV *S) {
-      if (auto *AR = dyn_cast<SCEVAddRecExpr>(S))
-        LoopsUsed.insert(AR->getLoop());
-      return true;
-    }
-
-    bool isDone() const { return false; }
-  };
+  SmallPtrSet<const Loop *, 8> LoopsUsed;
+  findUsedLoopsInSCEVExpr(S, LoopsUsed); // AddRec loops reachable from S.
+  for (auto *L : LoopsUsed)
+    LoopUsers[L].push_back({S}); // S is a SCEV user of L.
+}
 
-  FindUsedLoops F;
-  SCEVTraversal<FindUsedLoops>(F).visitAll(S);
+void ScalarEvolution::addToLoopUseLists(
+    const ScalarEvolution::BackedgeTakenInfo &BTI, const Loop *L) {
+  SmallPtrSet<const Loop *, 8> LoopsUsed; // Loops in L's trip counts.
+  BTI.findUsedLoops(*this, LoopsUsed);
 
-  for (auto *L : F.LoopsUsed)
-    LoopUsers[L].push_back(S);
+  for (auto *UsedL : LoopsUsed)
+    LoopUsers[UsedL].push_back({L}); // L's trip count uses UsedL.
 }
 
 void ScalarEvolution::verify() const {

Modified: llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp?rev=315713&r1=315712&r2=315713&view=diff
==============================================================================
--- llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp (original)
+++ llvm/trunk/unittests/Analysis/ScalarEvolutionTest.cpp Fri Oct 13 10:13:44 2017
@@ -24,11 +24,19 @@
 #include "llvm/IR/Module.h"
 #include "llvm/IR/Verifier.h"
 #include "llvm/Support/SourceMgr.h"
+#include "gmock/gmock.h"
 #include "gtest/gtest.h"
 
 namespace llvm {
 namespace {
 
+MATCHER_P3(IsAffineAddRec, S, X, L, "") { // Matches {S,+,X}<L>.
+  if (auto *AR = dyn_cast<SCEVAddRecExpr>(arg))
+    return AR->isAffine() && AR->getLoop() == L && AR->getOperand(0) == S &&
+           AR->getOperand(1) == X;
+  return false; // Not an AddRec at all.
+}
+
 // We use this fixture to ensure that we clean up ScalarEvolution before
 // deleting the PassManager.
 class ScalarEvolutionsTest : public testing::Test {
@@ -886,90 +894,6 @@ TEST_F(ScalarEvolutionsTest, SCEVExitLim
             2004u);
 }
 
-// Make sure that SCEV invalidates exit limits after invalidating the values it
-// depends on when we forget a value.
-TEST_F(ScalarEvolutionsTest, SCEVExitLimitForgetValue) {
-  /*
-   * Create the following code:
-   * func(i64 addrspace(10)* %arg)
-   * top:
-   *  br label %L.ph
-   * L.ph:
-   *  %load = load i64 addrspace(10)* %arg
-   *  br label %L
-   * L:
-   *  %phi = phi i64 [i64 0, %L.ph], [ %add, %L2 ]
-   *  %add = add i64 %phi2, 1
-   *  %cond = icmp slt i64 %add, %load ; then becomes 2000.
-   *  br i1 %cond, label %post, label %L2
-   * post:
-   *  ret void
-   *
-   */
-
-  // Create a module with non-integral pointers in it's datalayout
-  Module NIM("nonintegral", Context);
-  std::string DataLayout = M.getDataLayoutStr();
-  if (!DataLayout.empty())
-    DataLayout += "-";
-  DataLayout += "ni:10";
-  NIM.setDataLayout(DataLayout);
-
-  Type *T_int64 = Type::getInt64Ty(Context);
-  Type *T_pint64 = T_int64->getPointerTo(10);
-
-  FunctionType *FTy =
-      FunctionType::get(Type::getVoidTy(Context), {T_pint64}, false);
-  Function *F = cast<Function>(NIM.getOrInsertFunction("foo", FTy));
-
-  Argument *Arg = &*F->arg_begin();
-
-  BasicBlock *Top = BasicBlock::Create(Context, "top", F);
-  BasicBlock *LPh = BasicBlock::Create(Context, "L.ph", F);
-  BasicBlock *L = BasicBlock::Create(Context, "L", F);
-  BasicBlock *Post = BasicBlock::Create(Context, "post", F);
-
-  IRBuilder<> Builder(Top);
-  Builder.CreateBr(LPh);
-
-  Builder.SetInsertPoint(LPh);
-  auto *Load = cast<Instruction>(Builder.CreateLoad(T_int64, Arg, "load"));
-  Builder.CreateBr(L);
-
-  Builder.SetInsertPoint(L);
-  PHINode *Phi = Builder.CreatePHI(T_int64, 2);
-  auto *Add = cast<Instruction>(
-      Builder.CreateAdd(Phi, ConstantInt::get(T_int64, 1), "add"));
-  auto *Cond = cast<Instruction>(
-      Builder.CreateICmp(ICmpInst::ICMP_SLT, Add, Load, "cond"));
-  auto *Br = cast<Instruction>(Builder.CreateCondBr(Cond, L, Post));
-  Phi->addIncoming(ConstantInt::get(T_int64, 0), LPh);
-  Phi->addIncoming(Add, L);
-
-  Builder.SetInsertPoint(Post);
-  Builder.CreateRetVoid();
-
-  ScalarEvolution SE = buildSE(*F);
-  auto *Loop = LI->getLoopFor(L);
-  const SCEV *EC = SE.getBackedgeTakenCount(Loop);
-  EXPECT_FALSE(isa<SCEVCouldNotCompute>(EC));
-  EXPECT_FALSE(isa<SCEVConstant>(EC));
-
-  SE.forgetValue(Load);
-  Br->eraseFromParent();
-  Cond->eraseFromParent();
-  Load->eraseFromParent();
-
-  Builder.SetInsertPoint(L);
-  auto *NewCond = Builder.CreateICmp(
-      ICmpInst::ICMP_SLT, Add, ConstantInt::get(T_int64, 2000), "new.cond");
-  Builder.CreateCondBr(NewCond, L, Post);
-  const SCEV *NewEC = SE.getBackedgeTakenCount(Loop);
-  EXPECT_FALSE(isa<SCEVCouldNotCompute>(NewEC));
-  EXPECT_TRUE(isa<SCEVConstant>(NewEC));
-  EXPECT_EQ(cast<SCEVConstant>(NewEC)->getAPInt().getLimitedValue(), 1999u);
-}
-
 TEST_F(ScalarEvolutionsTest, SCEVAddRecFromPHIwithLargeConstants) {
   // Reference: https://reviews.llvm.org/D37265
   // Make sure that SCEV does not blow up when constructing an AddRec
@@ -1082,6 +1006,75 @@ TEST_F(ScalarEvolutionsTest, SCEVAddRecF
   auto Result = SE.createAddRecFromPHIWithCasts(cast<SCEVUnknown>(Expr));
 }
 
+TEST_F(ScalarEvolutionsTest, SCEVForgetDependentLoop) {
+  LLVMContext C;
+  SMDiagnostic Err;
+  std::unique_ptr<Module> M = parseAssemblyString(
+      "target datalayout = \"e-m:e-p:32:32-f64:32:64-f80:32-n8:16:32-S128\" "
+      " "
+      "define void @f(i32 %first_limit, i1* %cond) { "
+      "entry: "
+      " br label %first_loop.ph "
+      " "
+      "first_loop.ph: "
+      "  br label %first_loop "
+      " "
+      "first_loop: "
+      "  %iv_first = phi i32 [0, %first_loop.ph], [%iv_first.inc, %first_loop] "
+      "  %iv_first.inc = add i32 %iv_first, 1 "
+      "  %known_cond = icmp slt i32 %iv_first, 2000 "
+      "  %unknown_cond = load volatile i1, i1* %cond "
+      "  br i1 %unknown_cond, label %first_loop, label %first_loop.exit "
+      " "
+      "first_loop.exit: "
+      "  %iv_first.3x = mul i32 %iv_first, 3 "
+      "  %iv_first.5x = mul i32 %iv_first, 5 "
+      "  br label %second_loop.ph "
+      " "
+      "second_loop.ph: "
+      "  br label %second_loop "
+      " "
+      "second_loop: "
+      "  %iv_second = phi i32 [%iv_first.3x, %second_loop.ph], [%iv_second.inc, %second_loop] "
+      "  %iv_second.inc = add i32 %iv_second, 1 "
+      "  %second_loop.cond = icmp ne i32 %iv_second, %iv_first.5x "
+      "  br i1 %second_loop.cond, label %second_loop, label %second_loop.exit "
+      " "
+      "second_loop.exit: "
+      "  ret void "
+      "} "
+      " ",
+      Err, C);
+  // NOTE(review): assert() compiles out under NDEBUG; consider ASSERT_TRUE.
+  assert(M && "Could not parse module?");
+  assert(!verifyModule(*M) && "Must have been well formed!");
+  // second_loop's trip count is derived from iv_first, an IV of first_loop.
+  runWithSE(*M, "f", [&](Function &F, LoopInfo &LI, ScalarEvolution &SE) {
+    auto &FirstIV = GetInstByName(F, "iv_first");
+    auto &SecondIV = GetInstByName(F, "iv_second");
+
+    auto *FirstLoop = LI.getLoopFor(FirstIV.getParent());
+    auto *SecondLoop = LI.getLoopFor(SecondIV.getParent());
+
+    auto *Zero = SE.getZero(FirstIV.getType());
+    auto *Two = SE.getConstant(APInt(32, 2));
+    // first_loop exits on a volatile load, so its trip count is unknown.
+    EXPECT_EQ(SE.getBackedgeTakenCount(FirstLoop), SE.getCouldNotCompute());
+    EXPECT_THAT(SE.getBackedgeTakenCount(SecondLoop), // {0,+,2}<first_loop>
+                IsAffineAddRec(Zero, Two, FirstLoop));
+    // Make first_loop's exit branch use the analyzable known_cond instead.
+    auto &UnknownCond = GetInstByName(F, "unknown_cond");
+    auto &KnownCond = GetInstByName(F, "known_cond");
+
+    UnknownCond.replaceAllUsesWith(&KnownCond);
+    // Forgetting first_loop must also drop second_loop's cached trip count.
+    SE.forgetLoop(FirstLoop);
+    // iv_first exits at 2000, so second_loop runs from 6000 to 10000.
+    EXPECT_EQ(SE.getBackedgeTakenCount(FirstLoop), SE.getConstant(APInt(32, 2000)));
+    EXPECT_EQ(SE.getBackedgeTakenCount(SecondLoop), SE.getConstant(APInt(32, 4000)));
+  });
+}
+
 TEST_F(ScalarEvolutionsTest, SCEVFoldSumOfTruncs) {
   // Verify that the following SCEV gets folded to a zero:
   //  (-1 * (trunc i64 (-1 * %0) to i32)) + (-1 * (trunc i64 %0 to i32)




More information about the llvm-commits mailing list