[llvm-commits] beginning unroll improvements

Dan Gohman djg at cray.com
Wed May 9 09:50:30 PDT 2007


This patch extends the LoopUnroll pass so that it can unroll loops
with unknown trip counts. This is off by default; the new hidden
-unroll-count command-line option enables it. It also begins to move
loop unrolling into a separate utility routine (unrollLoop), which
might eventually be made usable from other passes.

It currently works by keeping a conditional exit branch after each
unrolled iteration, unless it can prove that the trip count is a
multiple of a constant integer greater than 1. At present it can only
prove this in the rare case where the trip count expression is a Mul
operator whose second operand is a ConstantInt. Eventually this
information might be provided by other sources, for example by a pass
that peels or splits the loop for this purpose.

Dan

-- 
Dan Gohman, Cray Inc.
-------------- next part --------------
Index: test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
===================================================================
RCS file: test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
diff -N test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
--- /dev/null
+++ test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
@@ -0,0 +1,18 @@
+; RUN: llvm-as < %s | opt -loop-unroll -unroll-count=3 | llvm-dis | grep bb72.2
+
+define void @foo(i32 %trips) {
+entry:
+	br label %cond_true.outer
+
+cond_true.outer:
+	%indvar1.ph = phi i32 [ 0, %entry ], [ %indvar.next2, %bb72 ]
+	br label %bb72
+
+bb72:
+	%indvar.next2 = add i32 %indvar1.ph, 1
+	%exitcond3 = icmp eq i32 %indvar.next2, %trips
+	br i1 %exitcond3, label %cond_true138, label %cond_true.outer
+
+cond_true138:
+	ret void
+}
Index: lib/Transforms/Scalar/LoopUnroll.cpp
===================================================================
RCS file: /var/cvs/llvm/llvm/lib/Transforms/Scalar/LoopUnroll.cpp,v
retrieving revision 1.46
diff -u -d -r1.46 LoopUnroll.cpp
--- lib/Transforms/Scalar/LoopUnroll.cpp
+++ lib/Transforms/Scalar/LoopUnroll.cpp
@@ -31,6 +31,7 @@
 #include "llvm/Support/Compiler.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -39,12 +40,19 @@
 #include <algorithm>
 using namespace llvm;
 
-STATISTIC(NumUnrolled, "Number of loops completely unrolled");
+STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled");
+STATISTIC(NumUnrolled,           "Number of loops unrolled (completely or otherwise)");
 
 namespace {
   cl::opt<unsigned>
-  UnrollThreshold("unroll-threshold", cl::init(100), cl::Hidden,
-                  cl::desc("The cut-off point for loop unrolling"));
+  UnrollThreshold
+    ("unroll-threshold", cl::init(100), cl::Hidden,
+     cl::desc("The cut-off point for automatic loop unrolling"));
+
+  cl::opt<unsigned>
+  UnrollCount
+    ("unroll-count", cl::init(0), cl::Hidden,
+     cl::desc("Use this unroll count for all loops, for testing purposes"));
 
   class VISIBILITY_HIDDEN LoopUnroll : public LoopPass {
     LoopInfo *LI;  // The current loop information
@@ -52,7 +60,13 @@
     static char ID; // Pass ID, replacement for typeid
     LoopUnroll() : LoopPass((intptr_t)&ID) {}
 
+    /// A magic value for use with the Threshold parameter to indicate
+    /// that the loop unroll should be performed regardless of how much
+    /// code expansion would result.
+    static const unsigned NoThreshold = UINT_MAX;
+
     bool runOnLoop(Loop *L, LPPassManager &LPM);
+    bool unrollLoop(Loop *L, unsigned Count, unsigned Threshold);
     BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB);
 
     /// This transformation requires natural loop information & requires that
@@ -162,43 +176,137 @@
 }
 
 bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
-  bool Changed = false;
   LI = &getAnalysis<LoopInfo>();
 
+  // Unroll the loop.
+  if (!unrollLoop(L, UnrollCount, UnrollThreshold))
+    return false;
+
+  // Update the loop information for this loop.
+  // If we completely unrolled the loop, remove it from the parent.
+  if (L->getNumBackEdges() == 0)
+    LPM.deleteLoopFromQueue(L);
+
+  return true;
+}
+
+/// Unroll the given loop by Count, or by a heuristically-determined
+/// value if Count is zero. Unless Threshold is NoThreshold, it limits
+/// code size expansion: if the unrolled loop size would exceed the
+/// threshold value, unrolling is suppressed. The return value is false
+/// if no transformations are performed.
+///
+bool LoopUnroll::unrollLoop(Loop *L, unsigned Count, unsigned Threshold) {
+  assert(L->isLCSSAForm());
+
   BasicBlock *Header = L->getHeader();
   BasicBlock *LatchBlock = L->getLoopLatch();
-
   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
-  if (BI == 0) return Changed;  // Must end in a conditional branch
 
-  ConstantInt *TripCountC = dyn_cast_or_null<ConstantInt>(L->getTripCount());
-  if (!TripCountC) return Changed;  // Must have constant trip count!
+  DOUT << "Loop Unroll: F[" << Header->getParent()->getName()
+       << "] Loop %" << Header->getName() << "\n";
 
-  // Guard against huge trip counts. This also guards against assertions in
-  // APInt from the use of getZExtValue, below.
-  if (TripCountC->getValue().getActiveBits() > 32)
-    return Changed; // More than 2^32 iterations???
+  if (!BI || BI->isUnconditional()) {
+    // The loop-rotate pass can be helpful to avoid this in many cases.
+    DOUT << "  Can't unroll; loop not terminated by a conditional branch.\n";
+    return false;
+  }
 
-  uint64_t TripCountFull = TripCountC->getZExtValue();
-  if (TripCountFull == 0)
-    return Changed; // Zero iteraitons?
+  // Determine the trip count and/or trip multiple. A TripCount value of zero
+  // is used to mean an unknown trip count. The TripMultiple value is a
+  // known integer that evenly divides the trip count (1 if none is known).
+  unsigned TripCount = 0;
+  unsigned TripMultiple = 1;
+  if (Value *TripCountValue = L->getTripCount()) {
+    if (ConstantInt *TripCountC = dyn_cast<ConstantInt>(TripCountValue)) {
+      // Guard against huge trip counts. This also guards against assertions in
+      // APInt from the use of getZExtValue, below.
+      if (TripCountC->getValue().getActiveBits() <= 32) {
+        TripCount = (unsigned)TripCountC->getZExtValue();
+      }
+    } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TripCountValue)) {
+      switch (BO->getOpcode()) {
+      case BinaryOperator::Mul:
+        if (ConstantInt *MultipleC = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+          if (MultipleC->getValue().getActiveBits() <= 32) {
+            TripMultiple = (unsigned)MultipleC->getZExtValue();
+          }
+        }
+        break;
+      default: break;
+      }
+    }
+  }
+  if (TripCount != 0)
+    DOUT << "  Trip Count = " << TripCount << "\n";
+  if (TripMultiple != 1)
+    DOUT << "  Trip Multiple = " << TripMultiple << "\n";
 
-  unsigned LoopSize = ApproximateLoopSize(L);
-  DOUT << "Loop Unroll: F[" << Header->getParent()->getName()
-       << "] Loop %" << Header->getName() << " Loop Size = "
-       << LoopSize << " Trip Count = " << TripCountFull << " - ";
-  uint64_t Size = (uint64_t)LoopSize*TripCountFull;
-  if (Size > UnrollThreshold) {
-    DOUT << "TOO LARGE: " << Size << ">" << UnrollThreshold << "\n";
-    return Changed;
+  // Automatically select an unroll count.
+  if (Count == 0) {
+    // Conservative heuristic: if we know the trip count, see if we can
+    // completely unroll (subject to the threshold, checked below); otherwise
+    // don't unroll.
+    if (TripCount != 0) {
+      Count = TripCount;
+    } else {
+      return false;
+    }
   }
-  DOUT << "UNROLLING!\n";
 
-  std::vector<BasicBlock*> LoopBlocks = L->getBlocks();
+  // Effectively "DCE" unrolled iterations that are beyond the trip count
+  // and would never be executed.
+  if (TripCount != 0 && Count > TripCount)
+    Count = TripCount;
 
-  unsigned TripCount = (unsigned)TripCountFull;
+  assert(Count > 0);
+  assert(TripMultiple > 0);
+  assert(TripCount == 0 || TripCount % TripMultiple == 0);
 
-  BasicBlock *LoopExit = BI->getSuccessor(L->contains(BI->getSuccessor(0))); 
+  // Enforce the threshold.
+  if (Threshold != NoThreshold) {
+    unsigned LoopSize = ApproximateLoopSize(L);
+    DOUT << "  Loop Size = " << LoopSize << "\n";
+    uint64_t Size = (uint64_t)LoopSize*Count;
+    if (TripCount != 1 && Size > Threshold) {
+      DOUT << "  TOO LARGE TO UNROLL: "
+           << Size << ">" << Threshold << "\n";
+      return false;
+    }
+  }
+
+  // Are we eliminating the loop control altogether?
+  bool CompletelyUnroll = Count == TripCount;
+
+  // If we know the trip count, we know the multiple...
+  unsigned BreakoutTrip = 0;
+  if (TripCount != 0) {
+    BreakoutTrip = TripCount % Count;
+    TripMultiple = 0;
+  } else {
+    // Figure out what multiple to use.
+    BreakoutTrip = TripMultiple =
+      (unsigned)GreatestCommonDivisor64(Count, TripMultiple);
+  }
+
+  if (CompletelyUnroll) {
+    DOUT << "COMPLETELY UNROLLING loop %" << Header->getName()
+         << " with trip count " << TripCount << "!\n";
+  } else {
+    DOUT << "UNROLLING loop %" << Header->getName()
+         << " by " << Count;
+    if (TripMultiple == 0 || BreakoutTrip != TripMultiple) {
+      DOUT << " with a breakout at trip " << BreakoutTrip;
+    } else if (TripMultiple != 1) {
+      DOUT << " with " << TripMultiple << " trips per branch";
+    }
+    DOUT << "!\n";
+  }
+
+  std::vector<BasicBlock*> LoopBlocks = L->getBlocks();
+
+  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
+  BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
 
   // For the first iteration of the loop, we should use the precloned values for
   // PHI nodes.  Insert associations now.
@@ -214,16 +322,12 @@
         LastValueMap[I] = I;
   }
 
-  // Remove the exit branch from the loop
-  LatchBlock->getInstList().erase(BI);
-  
   std::vector<BasicBlock*> Headers;
   std::vector<BasicBlock*> Latches;
   Headers.push_back(Header);
   Latches.push_back(LatchBlock);
 
-  assert(TripCount != 0 && "Trip count of 0 is impossible!");
-  for (unsigned It = 1; It != TripCount; ++It) {
+  for (unsigned It = 1; It != Count; ++It) {
     char SuffixBuffer[100];
     sprintf(SuffixBuffer, ".%d", It);
     
@@ -277,9 +381,18 @@
       // we can insert the proper branches later.
       if (*BB == Header)
         Headers.push_back(New);
-      if (*BB == LatchBlock)
+      if (*BB == LatchBlock) {
         Latches.push_back(New);
 
+        // Also, clear out the new latch's back edge so that it doesn't look
+        // like a new loop, so that it's amenable to being merged with adjacent
+        // blocks later on.
+        TerminatorInst *Term = New->getTerminator();
+        assert(L->contains(Term->getSuccessor(!ContinueOnTrue)));
+        assert(Term->getSuccessor(ContinueOnTrue) == LoopExit);
+        Term->setSuccessor(!ContinueOnTrue, NULL);
+      }
+
       NewBlocks.push_back(New);
     }
     
@@ -289,12 +402,11 @@
            E = NewBlocks[i]->end(); I != E; ++I)
         RemapInstruction(I, LastValueMap);
   }
-
   
   // The latch block exits the loop.  If there are any PHI nodes in the
   // successor blocks, update them to use the appropriate values computed as the
   // last iteration of the loop.
-  if (TripCount > 1) {
+  if (Count != 1) {
     SmallPtrSet<PHINode*, 8> Users;
     for (Value::use_iterator UI = LatchBlock->use_begin(),
          UE = LatchBlock->use_end(); UI != UE; ++UI)
@@ -316,29 +428,55 @@
     }
   }
 
-  // Now loop over the PHI nodes in the original block, setting them to their
-  // incoming values.
-  BasicBlock *Preheader = L->getLoopPreheader();
-  for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
-    PHINode *PN = OrigPHINode[i];
-    PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
-    Header->getInstList().erase(PN);
+  // Now, if we're doing complete unrolling, loop over the PHI nodes in the
+  // original block, setting them to their incoming values.
+  if (CompletelyUnroll) {
+    BasicBlock *Preheader = L->getLoopPreheader();
+    for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
+      PHINode *PN = OrigPHINode[i];
+      PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
+      Header->getInstList().erase(PN);
+    }
   }
-  
-  //  Insert the branches that link the different iterations together
-  for (unsigned i = 0; i < Latches.size()-1; ++i) {
-    new BranchInst(Headers[i+1], Latches[i]);
-    if (BasicBlock *Fold = FoldBlockIntoPredecessor(Headers[i+1])) {
-      std::replace(Latches.begin(), Latches.end(), Headers[i+1], Fold);
-      std::replace(Headers.begin(), Headers.end(), Headers[i+1], Fold);
+
+  // Now that all the basic blocks for the unrolled iterations are in place,
+  // set up the branches to connect them.
+  for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
+    // The original branch was replicated in each unrolled iteration.
+    BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator());
+
+    // The branch destination.
+    unsigned j = (i + 1) % e;
+    BasicBlock *Dest = Headers[j];
+    bool NeedConditional = true;
+
+    // For a complete unroll, make the last iteration end with a branch
+    // to the exit block.
+    if (CompletelyUnroll && j == 0) {
+      Dest = LoopExit;
+      NeedConditional = false;
+    }
+
+    // If we know the trip count, or an integer it is a multiple of, we can
+    // safely use an unconditional branch for some iterations.
+    if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
+      NeedConditional = false;
+    }
+
+    if (NeedConditional) {
+      // Update the conditional branch's successor for the following
+      // iteration.
+      Term->setSuccessor(!ContinueOnTrue, Dest);
+    } else {
+      Term->setUnconditionalDest(Dest);
+      // Merge adjacent basic blocks, if possible.
+      if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest)) {
+        std::replace(Latches.begin(), Latches.end(), Dest, Fold);
+        std::replace(Headers.begin(), Headers.end(), Dest, Fold);
+      }
     }
   }
   
-  // Finally, add an unconditional branch to the block to continue into the exit
-  // block.
-  new BranchInst(LoopExit, Latches[Latches.size()-1]);
-  FoldBlockIntoPredecessor(LoopExit);
-  
   // At this point, the code is well formed.  We now do a quick sweep over the
   // inserted code, doing constant propagation and dead code elimination as we
   // go.
@@ -356,10 +494,7 @@
       }
     }
 
-  // Update the loop information for this loop.
-  // Remove the loop from the parent.
-  LPM.deleteLoopFromQueue(L);
-
+  NumCompletelyUnrolled += CompletelyUnroll;
   ++NumUnrolled;
   return true;
 }


More information about the llvm-commits mailing list