[llvm] f7c54c4 - [LoopUnroll] Fold all exits based on known trip count/multiple

Nikita Popov via llvm-commits llvm-commits at lists.llvm.org
Thu Jun 17 11:58:43 PDT 2021


Author: Nikita Popov
Date: 2021-06-17T20:58:34+02:00
New Revision: f7c54c4603a2df8c0833f5ddc04a5e109bca1c14

URL: https://github.com/llvm/llvm-project/commit/f7c54c4603a2df8c0833f5ddc04a5e109bca1c14
DIFF: https://github.com/llvm/llvm-project/commit/f7c54c4603a2df8c0833f5ddc04a5e109bca1c14.diff

LOG: [LoopUnroll] Fold all exits based on known trip count/multiple

Fold all exits based on known trip count/multiple information from
SCEV. Previously, only the latch exit or the single exit was folded.

This doesn't yet eliminate ULO.TripCount and ULO.TripMultiple
entirely: they're still used a) to decide whether runtime unrolling
should be performed and b) for ORE remarks. However, the core
unrolling logic is independent of them now.

Differential Revision: https://reviews.llvm.org/D104203

Added: 
    

Modified: 
    llvm/lib/Transforms/Utils/LoopUnroll.cpp
    llvm/test/Transforms/LoopUnroll/multiple-exits.ll
    llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll

Removed: 
    


################################################################################
diff  --git a/llvm/lib/Transforms/Utils/LoopUnroll.cpp b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
index b3658fbe9e1c6..4d6112787a526 100644
--- a/llvm/lib/Transforms/Utils/LoopUnroll.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnroll.cpp
@@ -328,6 +328,37 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
   if (MaxTripCount && ULO.Count > MaxTripCount)
     ULO.Count = MaxTripCount;
 
+  struct ExitInfo {
+    unsigned TripCount;
+    unsigned TripMultiple;
+    unsigned BreakoutTrip;
+    bool ExitOnTrue;
+    SmallVector<BasicBlock *> ExitingBlocks;
+  };
+  DenseMap<BasicBlock *, ExitInfo> ExitInfos;
+  SmallVector<BasicBlock *, 4> ExitingBlocks;
+  L->getExitingBlocks(ExitingBlocks);
+  for (auto *ExitingBlock : ExitingBlocks) {
+    // The folding code is not prepared to deal with non-branch instructions
+    // right now.
+    auto *BI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
+    if (!BI)
+      continue;
+
+    ExitInfo &Info = ExitInfos.try_emplace(ExitingBlock).first->second;
+    Info.TripCount = SE->getSmallConstantTripCount(L, ExitingBlock);
+    Info.TripMultiple = SE->getSmallConstantTripMultiple(L, ExitingBlock);
+    if (Info.TripCount != 0) {
+      Info.BreakoutTrip = Info.TripCount % ULO.Count;
+      Info.TripMultiple = 0;
+    } else {
+      Info.BreakoutTrip = Info.TripMultiple =
+          (unsigned)GreatestCommonDivisor64(ULO.Count, Info.TripMultiple);
+    }
+    Info.ExitOnTrue = !L->contains(BI->getSuccessor(0));
+    Info.ExitingBlocks.push_back(ExitingBlock);
+  }
+
   // Are we eliminating the loop control altogether?  Note that we can know
   // we're eliminating the backedge without knowing exactly which iteration
   // of the unrolled body exits.
@@ -362,31 +393,12 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
 
   // A conditional branch which exits the loop, which can be optimized to an
   // unconditional branch in the unrolled loop in some cases.
-  BranchInst *ExitingBI = nullptr;
   bool LatchIsExiting = L->isLoopExiting(LatchBlock);
-  if (LatchIsExiting)
-    ExitingBI = LatchBI;
-  else if (BasicBlock *ExitingBlock = L->getExitingBlock())
-    ExitingBI = dyn_cast<BranchInst>(ExitingBlock->getTerminator());
   if (!LatchBI || (LatchBI->isConditional() && !LatchIsExiting)) {
     LLVM_DEBUG(
         dbgs() << "Can't unroll; a conditional latch must exit the loop");
     return LoopUnrollResult::Unmodified;
   }
-  LLVM_DEBUG({
-    if (ExitingBI)
-      dbgs() << "  Exiting Block = " << ExitingBI->getParent()->getName()
-             << "\n";
-    else
-      dbgs() << "  No single exiting block\n";
-  });
-
-  // Warning: ExactTripCount is the exact trip count for the block ending in
-  // ExitingBI, not neccessarily an exact exit count *for the loop*.  The
-  // distinction comes when we have an exiting latch, but the loop exits
-  // through another exit first.
-  const unsigned ExactTripCount = ExitingBI ?
-    SE->getSmallConstantTripCount(L,ExitingBI->getParent()) : 0;
 
   // Loops containing convergent instructions must have a count that divides
   // their TripMultiple.
@@ -421,6 +433,7 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
   }
 
   // If we know the trip count, we know the multiple...
+  // TODO: This is only used for the ORE code, remove it.
   unsigned BreakoutTrip = 0;
   if (ULO.TripCount != 0) {
     BreakoutTrip = ULO.TripCount % ULO.Count;
@@ -504,12 +517,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
   }
 
   std::vector<BasicBlock *> Headers;
-  std::vector<BasicBlock *> ExitingBlocks;
   std::vector<BasicBlock *> Latches;
   Headers.push_back(Header);
   Latches.push_back(LatchBlock);
-  if (ExitingBI)
-    ExitingBlocks.push_back(ExitingBI->getParent());
 
   // The current on-the-fly SSA update requires blocks to be processed in
   // reverse postorder so that LastValueMap contains the correct value at each
@@ -609,9 +619,9 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
 
       // Keep track of the exiting block and its successor block contained in
       // the loop for the current iteration.
-      if (ExitingBI)
-        if (*BB == ExitingBlocks[0])
-          ExitingBlocks.push_back(New);
+      auto ExitInfoIt = ExitInfos.find(*BB);
+      if (ExitInfoIt != ExitInfos.end())
+        ExitInfoIt->second.ExitingBlocks.push_back(New);
 
       NewBlocks.push_back(New);
       UnrolledLoopBlocks.push_back(New);
@@ -701,71 +711,79 @@ LoopUnrollResult llvm::UnrollLoop(Loop *L, UnrollLoopOptions ULO, LoopInfo *LI,
 
   DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
 
-  if (ExitingBI) {
-    auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
-      auto *Term = cast<BranchInst>(Src->getTerminator());
-      const unsigned Idx = ExitOnTrue ^ WillExit;
-      BasicBlock *Dest = Term->getSuccessor(Idx);
-      BasicBlock *DeadSucc = Term->getSuccessor(1-Idx);
+  auto SetDest = [&](BasicBlock *Src, bool WillExit, bool ExitOnTrue) {
+    auto *Term = cast<BranchInst>(Src->getTerminator());
+    const unsigned Idx = ExitOnTrue ^ WillExit;
+    BasicBlock *Dest = Term->getSuccessor(Idx);
+    BasicBlock *DeadSucc = Term->getSuccessor(1-Idx);
 
-      // Remove predecessors from all non-Dest successors.
-      DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true);
+    // Remove predecessors from all non-Dest successors.
+    DeadSucc->removePredecessor(Src, /* KeepOneInputPHIs */ true);
 
-      // Replace the conditional branch with an unconditional one.
-      BranchInst::Create(Dest, Term);
-      Term->eraseFromParent();
+    // Replace the conditional branch with an unconditional one.
+    BranchInst::Create(Dest, Term);
+    Term->eraseFromParent();
 
-      DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
-    };
+    DTU.applyUpdates({{DominatorTree::Delete, Src, DeadSucc}});
+  };
 
-    auto WillExit = [&](unsigned i, unsigned j) -> Optional<bool> {
-      if (CompletelyUnroll) {
-        if (PreserveOnlyFirst) {
-          if (i == 0)
-            return None;
-          return j == 0;
-        }
-        // Complete (but possibly inexact) unrolling
-        if (j == 0)
-          return true;
-        // Warning: ExactTripCount is the trip count of the exiting
-        // block which ends in ExitingBI, not neccessarily the loop.
-        if (ExactTripCount && j != ExactTripCount)
-          return false;
-        return None;
+  auto WillExit = [&](const ExitInfo &Info, unsigned i, unsigned j,
+                      bool IsLatch) -> Optional<bool> {
+    if (CompletelyUnroll) {
+      if (PreserveOnlyFirst) {
+        if (i == 0)
+          return None;
+        return j == 0;
       }
-
-      if (RuntimeTripCount && j != 0)
+      // Complete (but possibly inexact) unrolling
+      if (j == 0)
+        return true;
+      if (Info.TripCount && j != Info.TripCount)
         return false;
+      return None;
+    }
 
-      if (j != BreakoutTrip &&
-          (ULO.TripMultiple == 0 || j % ULO.TripMultiple != 0)) {
-        // If we know the trip count or a multiple of it, we can safely use an
-        // unconditional branch for some iterations.
+    if (RuntimeTripCount) {
+      // If runtime unrolling inserts a prologue, information about non-latch
+      // exits may be stale.
+      if (IsLatch && j != 0)
         return false;
-      }
       return None;
-    };
+    }
+
+    if (j != Info.BreakoutTrip &&
+        (Info.TripMultiple == 0 || j % Info.TripMultiple != 0)) {
+      // If we know the trip count or a multiple of it, we can safely use an
+      // unconditional branch for some iterations.
+      return false;
+    }
+    return None;
+  };
 
-    // Fold branches for iterations where we know that they will exit or not
-    // exit.
-    bool ExitOnTrue = !L->contains(ExitingBI->getSuccessor(0));
-    for (unsigned i = 0, e = ExitingBlocks.size(); i != e; ++i) {
+  // Fold branches for iterations where we know that they will exit or not
+  // exit.
+  for (const auto &Pair : ExitInfos) {
+    const ExitInfo &Info = Pair.second;
+    for (unsigned i = 0, e = Info.ExitingBlocks.size(); i != e; ++i) {
       // The branch destination.
       unsigned j = (i + 1) % e;
-      Optional<bool> KnownWillExit = WillExit(i, j);
+      bool IsLatch = Pair.first == LatchBlock;
+      Optional<bool> KnownWillExit = WillExit(Info, i, j, IsLatch);
       if (!KnownWillExit)
         continue;
 
-      // TODO: Also fold known-exiting branches for non-latch exits.
-      if (*KnownWillExit && !LatchIsExiting)
+      // We don't fold known-exiting branches for non-latch exits here,
+      // because this ensures that both all loop blocks and all exit blocks
+      // remain reachable in the CFG.
+      // TODO: We could fold these branches, but it would require much more
+      // sophisticated updates to LoopInfo.
+      if (*KnownWillExit && !IsLatch)
         continue;
 
-      SetDest(ExitingBlocks[i], *KnownWillExit, ExitOnTrue);
+      SetDest(Info.ExitingBlocks[i], *KnownWillExit, Info.ExitOnTrue);
     }
   }
 
-
   // When completely unrolling, the last latch becomes unreachable.
   if (!LatchIsExiting && CompletelyUnroll)
     changeToUnreachable(Latches.back()->getTerminator(), /* UseTrap */ false,

diff  --git a/llvm/test/Transforms/LoopUnroll/multiple-exits.ll b/llvm/test/Transforms/LoopUnroll/multiple-exits.ll
index 8a3f51a1fb94e..decba5c654ddc 100644
--- a/llvm/test/Transforms/LoopUnroll/multiple-exits.ll
+++ b/llvm/test/Transforms/LoopUnroll/multiple-exits.ll
@@ -9,49 +9,49 @@ define void @test1() {
 ; CHECK-NEXT:    br label [[LOOP:%.*]]
 ; CHECK:       loop:
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH:%.*]], label [[EXIT:%.*]]
+; CHECK-NEXT:    br label [[LATCH:%.*]]
 ; CHECK:       latch:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_1:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_1:%.*]]
 ; CHECK:       exit:
 ; CHECK-NEXT:    ret void
 ; CHECK:       latch.1:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_2:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_2:%.*]]
 ; CHECK:       latch.2:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_3:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_3:%.*]]
 ; CHECK:       latch.3:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_4:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_4:%.*]]
 ; CHECK:       latch.4:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_5:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_5:%.*]]
 ; CHECK:       latch.5:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_6:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_6:%.*]]
 ; CHECK:       latch.6:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_7:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_7:%.*]]
 ; CHECK:       latch.7:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_8:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_8:%.*]]
 ; CHECK:       latch.8:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 true, label [[LATCH_9:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br label [[LATCH_9:%.*]]
 ; CHECK:       latch.9:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    call void @bar()
-; CHECK-NEXT:    br i1 false, label [[LATCH_10:%.*]], label [[EXIT]]
+; CHECK-NEXT:    br i1 false, label [[LATCH_10:%.*]], label [[EXIT:%.*]]
 ; CHECK:       latch.10:
 ; CHECK-NEXT:    call void @bar()
 ; CHECK-NEXT:    br label [[EXIT]]

diff  --git a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
index 3076084c53309..1cd86ec145df8 100644
--- a/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
+++ b/llvm/test/Transforms/LoopUnroll/nonlatchcondbr.ll
@@ -168,7 +168,7 @@ define void @test3(i32* noalias %A, i1 %cond) {
 ; CHECK-NEXT:    call void @bar(i32 [[TMP0]])
 ; CHECK-NEXT:    br i1 [[COND:%.*]], label [[FOR_BODY:%.*]], label [[FOR_END:%.*]]
 ; CHECK:       for.body:
-; CHECK-NEXT:    br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]], label [[FOR_END]]
+; CHECK-NEXT:    br label [[FOR_BODY_FOR_BODY_CRIT_EDGE:%.*]]
 ; CHECK:       for.body.for.body_crit_edge:
 ; CHECK-NEXT:    [[ARRAYIDX_PHI_TRANS_INSERT:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 1
 ; CHECK-NEXT:    [[DOTPRE:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT]], align 4
@@ -177,14 +177,14 @@ define void @test3(i32* noalias %A, i1 %cond) {
 ; CHECK:       for.end:
 ; CHECK-NEXT:    ret void
 ; CHECK:       for.body.1:
-; CHECK-NEXT:    br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]], label [[FOR_END]]
+; CHECK-NEXT:    br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_1:%.*]]
 ; CHECK:       for.body.for.body_crit_edge.1:
 ; CHECK-NEXT:    [[ARRAYIDX_PHI_TRANS_INSERT_1:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 2
 ; CHECK-NEXT:    [[DOTPRE_1:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_1]], align 4
 ; CHECK-NEXT:    call void @bar(i32 [[DOTPRE_1]])
 ; CHECK-NEXT:    br i1 [[COND]], label [[FOR_BODY_2:%.*]], label [[FOR_END]]
 ; CHECK:       for.body.2:
-; CHECK-NEXT:    br i1 true, label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]], label [[FOR_END]]
+; CHECK-NEXT:    br label [[FOR_BODY_FOR_BODY_CRIT_EDGE_2:%.*]]
 ; CHECK:       for.body.for.body_crit_edge.2:
 ; CHECK-NEXT:    [[ARRAYIDX_PHI_TRANS_INSERT_2:%.*]] = getelementptr inbounds i32, i32* [[A]], i64 3
 ; CHECK-NEXT:    [[DOTPRE_2:%.*]] = load i32, i32* [[ARRAYIDX_PHI_TRANS_INSERT_2]], align 4


        


More information about the llvm-commits mailing list