[llvm] r336183 - [PM/LoopUnswitch] Fix PR37651 by correctly invalidating SCEV when unswitching loops.

Chandler Carruth via llvm-commits llvm-commits at lists.llvm.org
Tue Jul 3 02:13:27 PDT 2018


Author: chandlerc
Date: Tue Jul  3 02:13:27 2018
New Revision: 336183

URL: http://llvm.org/viewvc/llvm-project?rev=336183&view=rev
Log:
[PM/LoopUnswitch] Fix PR37651 by correctly invalidating SCEV when
unswitching loops.

The original patch trying to address this was sent in D47624, but it
didn't quite handle things correctly. There are two key principles used
to select whether and how to invalidate SCEV-cached information about
loops:

1) We must invalidate any info SCEV has cached before unswitching, as we
   may change (or destroy) the loop structure by the act of unswitching,
   making it hard to recover everything we want to invalidate within
   SCEV.

2) We need to invalidate all of the loops whose CFGs are mutated by the
   unswitching. Notably, this isn't the *entire* loop nest; it is every
   loop contained by the outermost loop reached by an exit block relevant
   to the unswitch.

And we need to do this even when doing trivial unswitching.
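
To make principle 2 concrete, here is a small self-contained C++ model of the decision the patch makes for a single exit (compare the `unswitchTrivialBranch` hunk below). `MockLoop`, `MockSCEVCache`, `addSubLoop` and `invalidateForExit` are hypothetical stand-ins invented for illustration, not LLVM APIs. The cache only records which loops have cached information, and forgetting a loop is modeled as also dropping every loop nested in it, matching the patch's own comment about invalidating "the outer most loop containing an exit block and all nested loops".

```cpp
#include <cassert>
#include <set>
#include <vector>

// Hypothetical stand-in for llvm::Loop: a node in a loop nest with a parent
// link and its directly nested child loops.
struct MockLoop {
  MockLoop *Parent = nullptr;
  std::vector<MockLoop *> SubLoops;
};

// Hypothetical helper: nest Child directly inside Parent.
void addSubLoop(MockLoop &Parent, MockLoop &Child) {
  assert(!Child.Parent && "loop already has a parent");
  Child.Parent = &Parent;
  Parent.SubLoops.push_back(&Child);
}

// Hypothetical stand-in for ScalarEvolution's per-loop caches: it only
// tracks which loops currently have cached information.
struct MockSCEVCache {
  std::set<const MockLoop *> Cached;

  // Modeled on forgetLoop: drop L and every loop nested inside it.
  void forgetLoop(const MockLoop *L) {
    Cached.erase(L);
    for (const MockLoop *Sub : L->SubLoops)
      forgetLoop(Sub);
  }

  // Modeled on forgetTopmostLoop: climb to the outermost loop of the nest
  // and forget the whole nest.
  void forgetTopmostLoop(const MockLoop *L) {
    while (L->Parent)
      L = L->Parent;
    forgetLoop(L);
  }
};

// The rule for a single exit: ExitL is the innermost loop containing the
// exit block, or nullptr when the exit leaves the entire loop nest.
void invalidateForExit(MockSCEVCache &SE, const MockLoop &L,
                       const MockLoop *ExitL) {
  if (ExitL)
    SE.forgetLoop(ExitL);
  else
    SE.forgetTopmostLoop(&L);
}
```

With a nest `Top > Mid > Leaf` plus a sibling of `Mid`, unswitching `Leaf` toward an exit that lands in `Mid` clears only `Mid` and `Leaf`; `Top` and the sibling keep their entries, which is the "not the *entire* loop nest" point above. An exit that leaves the whole nest clears everything.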

I've added more focused tests that directly check that SCEV starts off
with imprecise information and after unswitching (and simplifying
instructions) re-querying SCEV will produce precise information. These
tests also specifically work to check that an *outer* loop's information
becomes precise.
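
These tests rely on SCEV memoizing its answers: a result computed before unswitching keeps being served from the cache until the loop is forgotten. The hypothetical sketch below reproduces that effect with a toy memoizing analysis (`TripCountCache`, `LoopShape` and `kUnpredictable` are invented for illustration and are far simpler than `ScalarEvolution`). It assumes a loop is analyzable only when it has a single exit, and borrows the closed form `max(1, m) - 1` from the test's expected `(-1 + (1 smax %m))`.

```cpp
#include <algorithm>
#include <cassert>
#include <map>
#include <string>

// Stand-in for "Unpredictable backedge-taken count".
constexpr int kUnpredictable = -1;

// Hypothetical memoizing trip-count analysis; not ScalarEvolution's API.
class TripCountCache {
public:
  // Stand-in for a loop's IR: its number of exits and its bound %m.
  struct LoopShape {
    int NumExits;
    int Bound;
  };

  // Returns the cached answer if one exists, even if Shape has changed since
  // it was computed. Otherwise computes, caches and returns a fresh answer.
  int getBackedgeTakenCount(const std::string &Name, const LoopShape &Shape) {
    auto It = Cache.find(Name);
    if (It != Cache.end())
      return It->second;
    // Assumption: only single-exit loops get a computable count.
    int Result = Shape.NumExits == 1 ? std::max(1, Shape.Bound) - 1
                                     : kUnpredictable;
    Cache[Name] = Result;
    return Result;
  }

  // Drop the memoized answer so the next query recomputes it.
  void forgetLoop(const std::string &Name) { Cache.erase(Name); }

private:
  std::map<std::string, int> Cache;
};
```

Querying the same loop again after its shape changes still returns `kUnpredictable`; only after `forgetLoop` does the re-query produce the precise count. That stale second answer is the kind of result missing invalidation leaves behind.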

However, the testing here is still a bit imperfect. Crafting test cases
that reliably fail to be analyzed by SCEV before unswitching and succeed
afterward proved ... very, very hard. It took me several hours and
careful work to build these, and I'm not optimistic about coming up with
more that cover more elaborate possibilities. Fortunately, the code
pattern we are testing here in the pass is really straightforward and
reliable.

Thanks to Max Kazantsev for the initial work on this as well as the
review, and to Hal Finkel for helping me talk through approaches to test
this stuff even if it didn't come to much.

Differential Revision: https://reviews.llvm.org/D47624

Added:
    llvm/trunk/test/Transforms/SimpleLoopUnswitch/update-scev.ll
Modified:
    llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp

Modified: llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp?rev=336183&r1=336182&r2=336183&view=diff
==============================================================================
--- llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp (original)
+++ llvm/trunk/lib/Transforms/Scalar/SimpleLoopUnswitch.cpp Tue Jul  3 02:13:27 2018
@@ -253,8 +253,11 @@ static void rewritePHINodesForExitAndUns
 /// (splitting the exit block as necessary). It simplifies the branch within
 /// the loop to an unconditional branch but doesn't remove it entirely. Further
 /// cleanup can be done with some simplify-cfg like pass.
+///
+/// If `SE` is not null, it will be updated based on the potential loop SCEVs
+/// invalidated by this.
 static bool unswitchTrivialBranch(Loop &L, BranchInst &BI, DominatorTree &DT,
-                                  LoopInfo &LI) {
+                                  LoopInfo &LI, ScalarEvolution *SE) {
   assert(BI.isConditional() && "Can only unswitch a conditional branch!");
   LLVM_DEBUG(dbgs() << "  Trying to unswitch branch: " << BI << "\n");
 
@@ -318,6 +321,16 @@ static bool unswitchTrivialBranch(Loop &
     }
   });
 
+  // If we have scalar evolutions, we need to invalidate them including this
+  // loop and the loop containing the exit block.
+  if (SE) {
+    if (Loop *ExitL = LI.getLoopFor(LoopExitBB))
+      SE->forgetLoop(ExitL);
+    else
+      // Forget the entire nest as this exits the entire nest.
+      SE->forgetTopmostLoop(&L);
+  }
+
   // Split the preheader, so that we know that there is a safe place to insert
   // the conditional branch. We will change the preheader to have a conditional
   // branch on LoopCond.
@@ -420,8 +433,11 @@ static bool unswitchTrivialBranch(Loop &
 /// switch will not be revisited. If after unswitching there is only a single
 /// in-loop successor, the switch is further simplified to an unconditional
 /// branch. Still more cleanup can be done with some simplify-cfg like pass.
+///
+/// If `SE` is not null, it will be updated based on the potential loop SCEVs
+/// invalidated by this.
 static bool unswitchTrivialSwitch(Loop &L, SwitchInst &SI, DominatorTree &DT,
-                                  LoopInfo &LI) {
+                                  LoopInfo &LI, ScalarEvolution *SE) {
   LLVM_DEBUG(dbgs() << "  Trying to unswitch switch: " << SI << "\n");
   Value *LoopCond = SI.getCondition();
 
@@ -448,18 +464,33 @@ static bool unswitchTrivialSwitch(Loop &
 
   LLVM_DEBUG(dbgs() << "    unswitching trivial cases...\n");
 
+  // We may need to invalidate SCEVs for the outermost loop reached by any of
+  // the exits.
+  Loop *OuterL = &L;
+
   SmallVector<std::pair<ConstantInt *, BasicBlock *>, 4> ExitCases;
   ExitCases.reserve(ExitCaseIndices.size());
   // We walk the case indices backwards so that we remove the last case first
   // and don't disrupt the earlier indices.
   for (unsigned Index : reverse(ExitCaseIndices)) {
     auto CaseI = SI.case_begin() + Index;
+    // Compute the outer loop from this exit.
+    Loop *ExitL = LI.getLoopFor(CaseI->getCaseSuccessor());
+    if (!ExitL || ExitL->contains(OuterL))
+      OuterL = ExitL;
     // Save the value of this case.
     ExitCases.push_back({CaseI->getCaseValue(), CaseI->getCaseSuccessor()});
     // Delete the unswitched cases.
     SI.removeCase(CaseI);
   }
 
+  if (SE) {
+    if (OuterL)
+      SE->forgetLoop(OuterL);
+    else
+      SE->forgetTopmostLoop(&L);
+  }
+
   // Check if after this all of the remaining cases point at the same
   // successor.
   BasicBlock *CommonSuccBB = nullptr;
@@ -617,8 +648,11 @@ static bool unswitchTrivialSwitch(Loop &
 ///
 /// The return value indicates whether anything was unswitched (and therefore
 /// changed).
+///
+/// If `SE` is not null, it will be updated based on the potential loop SCEVs
+/// invalidated by this.
 static bool unswitchAllTrivialConditions(Loop &L, DominatorTree &DT,
-                                         LoopInfo &LI) {
+                                         LoopInfo &LI, ScalarEvolution *SE) {
   bool Changed = false;
 
   // If loop header has only one reachable successor we should keep looking for
@@ -652,7 +686,7 @@ static bool unswitchAllTrivialConditions
       if (isa<Constant>(SI->getCondition()))
         return Changed;
 
-      if (!unswitchTrivialSwitch(L, *SI, DT, LI))
+      if (!unswitchTrivialSwitch(L, *SI, DT, LI, SE))
         // Couldn't unswitch this one so we're done.
         return Changed;
 
@@ -684,7 +718,7 @@ static bool unswitchAllTrivialConditions
 
     // Found a trivial condition candidate: non-foldable conditional branch. If
     // we fail to unswitch this, we can't do anything else that is trivial.
-    if (!unswitchTrivialBranch(L, *BI, DT, LI))
+    if (!unswitchTrivialBranch(L, *BI, DT, LI, SE))
       return Changed;
 
     // Mark that we managed to unswitch something.
@@ -1622,7 +1656,8 @@ void visitDomSubTree(DominatorTree &DT,
 static bool unswitchNontrivialInvariants(
     Loop &L, TerminatorInst &TI, ArrayRef<Value *> Invariants,
     DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+    function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+    ScalarEvolution *SE) {
   auto *ParentBB = TI.getParent();
   BranchInst *BI = dyn_cast<BranchInst>(&TI);
   SwitchInst *SI = BI ? nullptr : cast<SwitchInst>(&TI);
@@ -1705,6 +1740,16 @@ static bool unswitchNontrivialInvariants
       OuterExitL = NewOuterExitL;
   }
 
+  // At this point, we're definitely going to unswitch something so invalidate
+  // any cached information in ScalarEvolution for the outer most loop
+  // containing an exit block and all nested loops.
+  if (SE) {
+    if (OuterExitL)
+      SE->forgetLoop(OuterExitL);
+    else
+      SE->forgetTopmostLoop(&L);
+  }
+
   // If the edge from this terminator to a successor dominates that successor,
   // store a map from each block in its dominator subtree to it. This lets us
   // tell when cloning for a particular successor if a block is dominated by
@@ -1968,10 +2013,11 @@ computeDomSubtreeCost(DomTreeNode &N,
   return Cost;
 }
 
-static bool unswitchBestCondition(
-    Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-    TargetTransformInfo &TTI,
-    function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+static bool
+unswitchBestCondition(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                      AssumptionCache &AC, TargetTransformInfo &TTI,
+                      function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+                      ScalarEvolution *SE) {
   // Collect all invariant conditions within this loop (as opposed to an inner
   // loop which would be handled when visiting that inner loop).
   SmallVector<std::pair<TerminatorInst *, TinyPtrVector<Value *>>, 4>
@@ -2164,7 +2210,7 @@ static bool unswitchBestCondition(
                     << BestUnswitchCost << ") terminator: " << *BestUnswitchTI
                     << "\n");
   return unswitchNontrivialInvariants(
-      L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB);
+      L, *BestUnswitchTI, BestUnswitchInvariants, DT, LI, AC, UnswitchCB, SE);
 }
 
 /// Unswitch control flow predicated on loop invariant conditions.
@@ -2173,10 +2219,25 @@ static bool unswitchBestCondition(
 /// require duplicating any part of the loop) out of the loop body. It then
 /// looks at other loop invariant control flows and tries to unswitch those as
 /// well by cloning the loop if the result is small enough.
-static bool
-unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI, AssumptionCache &AC,
-             TargetTransformInfo &TTI, bool NonTrivial,
-             function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB) {
+///
+/// The `DT`, `LI`, `AC`, `TTI` parameters are required analyses that are also
+/// updated based on the unswitch.
+///
+/// If either `NonTrivial` is true or the flag `EnableNonTrivialUnswitch` is
+/// true, we will attempt to do non-trivial unswitching as well as trivial
+/// unswitching.
+///
+/// The `UnswitchCB` callback provided will be run after unswitching is
+/// complete, with the first parameter set to `true` if the provided loop
+/// remains a loop, and a list of new sibling loops created.
+///
+/// If `SE` is non-null, we will update that analysis based on the unswitching
+/// done.
+static bool unswitchLoop(Loop &L, DominatorTree &DT, LoopInfo &LI,
+                         AssumptionCache &AC, TargetTransformInfo &TTI,
+                         bool NonTrivial,
+                         function_ref<void(bool, ArrayRef<Loop *>)> UnswitchCB,
+                         ScalarEvolution *SE) {
   assert(L.isRecursivelyLCSSAForm(DT, LI) &&
          "Loops must be in LCSSA form before unswitching.");
   bool Changed = false;
@@ -2186,7 +2247,7 @@ unswitchLoop(Loop &L, DominatorTree &DT,
     return false;
 
   // Try trivial unswitch first before loop over other basic blocks in the loop.
-  if (unswitchAllTrivialConditions(L, DT, LI)) {
+  if (unswitchAllTrivialConditions(L, DT, LI, SE)) {
     // If we unswitched successfully we will want to clean up the loop before
     // processing it further so just mark it as unswitched and return.
     UnswitchCB(/*CurrentLoopValid*/ true, {});
@@ -2207,7 +2268,7 @@ unswitchLoop(Loop &L, DominatorTree &DT,
 
   // Try to unswitch the best invariant condition. We prefer this full unswitch to
   // a partial unswitch when possible below the threshold.
-  if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB))
+  if (unswitchBestCondition(L, DT, LI, AC, TTI, UnswitchCB, SE))
     return true;
 
   // No other opportunities to unswitch.
@@ -2241,8 +2302,8 @@ PreservedAnalyses SimpleLoopUnswitchPass
       U.markLoopAsDeleted(L, LoopName);
   };
 
-  if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial,
-                    UnswitchCB))
+  if (!unswitchLoop(L, AR.DT, AR.LI, AR.AC, AR.TTI, NonTrivial, UnswitchCB,
+                    &AR.SE))
     return PreservedAnalyses::all();
 
   // Historically this pass has had issues with the dominator tree so verify it
@@ -2290,6 +2351,9 @@ bool SimpleLoopUnswitchLegacyPass::runOn
   auto &AC = getAnalysis<AssumptionCacheTracker>().getAssumptionCache(F);
   auto &TTI = getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
 
+  auto *SEWP = getAnalysisIfAvailable<ScalarEvolutionWrapperPass>();
+  auto *SE = SEWP ? &SEWP->getSE() : nullptr;
+
   auto UnswitchCB = [&L, &LPM](bool CurrentLoopValid,
                                ArrayRef<Loop *> NewLoops) {
     // If we did a non-trivial unswitch, we have added new (cloned) loops.
@@ -2305,8 +2369,7 @@ bool SimpleLoopUnswitchLegacyPass::runOn
       LPM.markLoopAsDeleted(*L);
   };
 
-  bool Changed =
-      unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB);
+  bool Changed = unswitchLoop(*L, DT, LI, AC, TTI, NonTrivial, UnswitchCB, SE);
 
   // If anything was unswitched, also clear any cached information about this
   // loop.

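For a switch, several cases may exit to different loops, so `unswitchTrivialSwitch` folds over the exits to find the outermost loop reached. The sketch below isolates that fold in plain C++. `MockLoop` and `outermostExitLoop` are hypothetical names; `contains` is modeled on LLVM's `Loop::contains(const Loop *)`, assumed here to return false for a null argument, which is what keeps a null result "sticky" once any exit leaves the whole nest.

```cpp
#include <cassert>
#include <vector>

// Hypothetical minimal loop node (not llvm::Loop): only the parent link is
// needed to answer "does this loop contain that one?".
struct MockLoop {
  MockLoop *Parent = nullptr;

  // True when Other is this loop or nested inside it; false for nullptr.
  bool contains(const MockLoop *Other) const {
    for (; Other; Other = Other->Parent)
      if (Other == this)
        return true;
    return false;
  }
};

// Mirrors the walk over ExitCaseIndices: start from the unswitched loop and
// widen to the outermost loop that any exit lands in. Each ExitLoops entry
// is the innermost loop containing an exit block (nullptr if that exit
// leaves the entire nest). A nullptr result means the whole nest must be
// forgotten, i.e. the forgetTopmostLoop(&L) branch in the diff.
MockLoop *outermostExitLoop(MockLoop &L,
                            const std::vector<MockLoop *> &ExitLoops) {
  MockLoop *OuterL = &L;
  for (MockLoop *ExitL : ExitLoops)
    if (!ExitL || ExitL->contains(OuterL))
      OuterL = ExitL;
  return OuterL;
}
```

Once any exit yields nullptr, `OuterL` stays nullptr for the rest of the walk, because no loop contains a null loop.
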
Added: llvm/trunk/test/Transforms/SimpleLoopUnswitch/update-scev.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/SimpleLoopUnswitch/update-scev.ll?rev=336183&view=auto
==============================================================================
--- llvm/trunk/test/Transforms/SimpleLoopUnswitch/update-scev.ll (added)
+++ llvm/trunk/test/Transforms/SimpleLoopUnswitch/update-scev.ll Tue Jul  3 02:13:27 2018
@@ -0,0 +1,188 @@
+; RUN: opt -passes='print<scalar-evolution>,loop(unswitch,loop-instsimplify),print<scalar-evolution>' -enable-nontrivial-unswitch -S < %s 2>%t.scev | FileCheck %s
+; RUN: FileCheck %s --check-prefix=SCEV < %t.scev
+
+target triple = "x86_64-unknown-linux-gnu"
+
+declare void @f()
+
+; Check that trivially unswitching an inner loop resets both the inner and outer
+; loop trip count.
+define void @test1(i32 %n, i32 %m, i1 %cond) {
+; Check that SCEV has no trip count before unswitching.
+; SCEV-LABEL: Determining loop execution counts for: @test1
+; SCEV: Loop %inner_loop_begin: <multiple exits> Unpredictable backedge-taken count.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; Now check that after unswitching and simplifying instructions we get clean
+; backedge-taken counts.
+; SCEV-LABEL: Determining loop execution counts for: @test1
+; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m))<nsw>
+; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n))<nsw>
+;
+; And verify the code matches what we expect.
+; CHECK-LABEL: define void @test1(
+entry:
+  br label %outer_loop_begin
+; Ensure the outer loop didn't get unswitched.
+; CHECK:       entry:
+; CHECK-NEXT:    br label %outer_loop_begin
+
+outer_loop_begin:
+  %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ]
+  ; Block unswitching of the outer loop with a noduplicate call.
+  call void @f() noduplicate
+  br label %inner_loop_begin
+; Ensure the inner loop got unswitched into the outer loop.
+; CHECK:       outer_loop_begin:
+; CHECK-NEXT:    %{{.*}} = phi i32
+; CHECK-NEXT:    call void @f()
+; CHECK-NEXT:    br i1 %cond,
+
+inner_loop_begin:
+  %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ]
+  br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit
+
+inner_loop_latch:
+  %j.next = add nsw i32 %j, 1
+  %j.cmp = icmp slt i32 %j.next, %m
+  br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit
+
+inner_loop_early_exit:
+  %j.lcssa = phi i32 [ %i, %inner_loop_begin ]
+  br label %outer_loop_latch
+
+inner_loop_late_exit:
+  br label %outer_loop_latch
+
+outer_loop_latch:
+  %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ]
+  %i.next = add nsw i32 %i.phi, 1
+  %i.cmp = icmp slt i32 %i.next, %n
+  br i1 %i.cmp, label %outer_loop_begin, label %exit
+
+exit:
+  ret void
+}
+
+; Check that trivially unswitching an inner loop resets both the inner and outer
+; loop trip count.
+define void @test2(i32 %n, i32 %m, i32 %cond) {
+; Check that SCEV has no trip count before unswitching.
+; SCEV-LABEL: Determining loop execution counts for: @test2
+; SCEV: Loop %inner_loop_begin: <multiple exits> Unpredictable backedge-taken count.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; Now check that after unswitching and simplifying instructions we get clean
+; backedge-taken counts.
+; SCEV-LABEL: Determining loop execution counts for: @test2
+; SCEV: Loop %inner_loop_begin: backedge-taken count is (-1 + (1 smax %m))<nsw>
+; FIXME: The following backedge taken count should be known but isn't apparently
+; just because of a switch in the outer loop.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; CHECK-LABEL: define void @test2(
+entry:
+  br label %outer_loop_begin
+; Ensure the outer loop didn't get unswitched.
+; CHECK:       entry:
+; CHECK-NEXT:    br label %outer_loop_begin
+
+outer_loop_begin:
+  %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ]
+  ; Block unswitching of the outer loop with a noduplicate call.
+  call void @f() noduplicate
+  br label %inner_loop_begin
+; Ensure the inner loop got unswitched into the outer loop.
+; CHECK:       outer_loop_begin:
+; CHECK-NEXT:    %{{.*}} = phi i32
+; CHECK-NEXT:    call void @f()
+; CHECK-NEXT:    switch i32 %cond,
+
+inner_loop_begin:
+  %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ]
+  switch i32 %cond, label %inner_loop_early_exit [
+    i32 1, label %inner_loop_latch
+    i32 2, label %inner_loop_latch
+  ]
+
+inner_loop_latch:
+  %j.next = add nsw i32 %j, 1
+  %j.cmp = icmp slt i32 %j.next, %m
+  br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit
+
+inner_loop_early_exit:
+  %j.lcssa = phi i32 [ %i, %inner_loop_begin ]
+  br label %outer_loop_latch
+
+inner_loop_late_exit:
+  br label %outer_loop_latch
+
+outer_loop_latch:
+  %i.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ %i, %inner_loop_late_exit ]
+  %i.next = add nsw i32 %i.phi, 1
+  %i.cmp = icmp slt i32 %i.next, %n
+  br i1 %i.cmp, label %outer_loop_begin, label %exit
+
+exit:
+  ret void
+}
+
+; Check that non-trivial unswitching of a branch in an inner loop into the outer
+; loop invalidates both inner and outer.
+define void @test3(i32 %n, i32 %m, i1 %cond) {
+; Check that SCEV has no trip count before unswitching.
+; SCEV-LABEL: Determining loop execution counts for: @test3
+; SCEV: Loop %inner_loop_begin: <multiple exits> Unpredictable backedge-taken count.
+; SCEV: Loop %outer_loop_begin: Unpredictable backedge-taken count.
+;
+; Now check that after unswitching and simplifying instructions we get clean
+; backedge-taken counts.
+; SCEV-LABEL: Determining loop execution counts for: @test3
+; SCEV: Loop %inner_loop_begin{{.*}}: backedge-taken count is (-1 + (1 smax %m))<nsw>
+; SCEV: Loop %outer_loop_begin: backedge-taken count is (-1 + (1 smax %n))<nsw>
+;
+; And verify the code matches what we expect.
+; CHECK-LABEL: define void @test3(
+entry:
+  br label %outer_loop_begin
+; Ensure the outer loop didn't get unswitched.
+; CHECK:       entry:
+; CHECK-NEXT:    br label %outer_loop_begin
+
+outer_loop_begin:
+  %i = phi i32 [ %i.next, %outer_loop_latch ], [ 0, %entry ]
+  ; Block unswitching of the outer loop with a noduplicate call.
+  call void @f() noduplicate
+  br label %inner_loop_begin
+; Ensure the inner loop got unswitched into the outer loop.
+; CHECK:       outer_loop_begin:
+; CHECK-NEXT:    %{{.*}} = phi i32
+; CHECK-NEXT:    call void @f()
+; CHECK-NEXT:    br i1 %cond,
+
+inner_loop_begin:
+  %j = phi i32 [ %j.next, %inner_loop_latch ], [ 0, %outer_loop_begin ]
+  %j.tmp = add nsw i32 %j, 1
+  br i1 %cond, label %inner_loop_latch, label %inner_loop_early_exit
+
+inner_loop_latch:
+  %j.next = add nsw i32 %j, 1
+  %j.cmp = icmp slt i32 %j.next, %m
+  br i1 %j.cmp, label %inner_loop_begin, label %inner_loop_late_exit
+
+inner_loop_early_exit:
+  %j.lcssa = phi i32 [ %j.tmp, %inner_loop_begin ]
+  br label %outer_loop_latch
+
+inner_loop_late_exit:
+  br label %outer_loop_latch
+
+outer_loop_latch:
+  %inc.phi = phi i32 [ %j.lcssa, %inner_loop_early_exit ], [ 1, %inner_loop_late_exit ]
+  %i.next = add nsw i32 %i, %inc.phi
+  %i.cmp = icmp slt i32 %i.next, %n
+  br i1 %i.cmp, label %outer_loop_begin, label %exit
+
+exit:
+  ret void
+}

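The closed forms checked by the `SCEV:` lines, such as `(-1 + (1 smax %m))` for `%inner_loop_begin` in `@test1`, can be sanity-checked by simulating the inner loop once `%cond` is known to be true, i.e. after the early exit has been unswitched away. This is a hypothetical illustration in plain C++ (the function names are invented, and nothing here is generated by LLVM); it assumes the `nsw` flags hold, so `j + 1` never overflows for the small bounds used.

```cpp
#include <algorithm>
#include <cassert>

// Counts how often the inner loop's backedge is taken by direct simulation:
//   j = 0; do { j.next = j + 1; } while (j.next < m)   (signed compare)
int simulateInnerBackedges(int M) {
  int Taken = 0;
  int J = 0;
  for (;;) {
    int JNext = J + 1; // %j.next = add nsw i32 %j, 1
    if (!(JNext < M))  // %j.cmp = icmp slt i32 %j.next, %m
      break;           // leave via %inner_loop_late_exit
    ++Taken;           // branch back to %inner_loop_begin
    J = JNext;
  }
  return Taken;
}

// Closed form matching the expected SCEV string (-1 + (1 smax %m)).
int closedFormBackedges(int M) { return std::max(1, M) - 1; }
```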