[llvm-commits] beginning unroll improvements
Dan Gohman
djg at cray.com
Wed May 9 09:50:30 PDT 2007
This patch extends the LoopUnroll pass to be able to unroll loops
with unknown trip counts. This is left off by default, and a
command-line option enables it. It also begins to separate loop
unrolling into a utility routine; eventually it might be made usable
from other passes.
It currently works by inserting conditional branches between each
unrolled iteration, unless it proves that the trip count is a
multiple of a constant integer > 1, which it currently only does in
the rare case that the trip count expression is a Mul operator with
a ConstantInt operand. Eventually this information might be provided
by other sources, for example by a pass that peels/splits the loop
for this purpose.
Dan
--
Dan Gohman, Cray Inc.
-------------- next part --------------
Index: test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
===================================================================
RCS file: test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
diff -N test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
--- /dev/null
+++ test/Transforms/LoopUnroll/2007-05-09-UnknownTripCount.ll
@@ -0,0 +1,18 @@
+; RUN: llvm-as < %s | opt -loop-unroll -unroll-count=3 | llvm-dis | grep bb72.2
+
+define void @foo(i32 %trips) {
+entry:
+ br label %cond_true.outer
+
+cond_true.outer:
+ %indvar1.ph = phi i32 [ 0, %entry ], [ %indvar.next2, %bb72 ]
+ br label %bb72
+
+bb72:
+ %indvar.next2 = add i32 %indvar1.ph, 1
+ %exitcond3 = icmp eq i32 %indvar.next2, %trips
+ br i1 %exitcond3, label %cond_true138, label %cond_true.outer
+
+cond_true138:
+ ret void
+}
Index: lib/Transforms/Scalar/LoopUnroll.cpp
===================================================================
RCS file: /var/cvs/llvm/llvm/lib/Transforms/Scalar/LoopUnroll.cpp,v
retrieving revision 1.46
diff -u -d -r1.46 LoopUnroll.cpp
--- lib/Transforms/Scalar/LoopUnroll.cpp
+++ lib/Transforms/Scalar/LoopUnroll.cpp
@@ -31,6 +31,7 @@
#include "llvm/Support/Compiler.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#include "llvm/ADT/Statistic.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallPtrSet.h"
@@ -39,12 +40,19 @@
#include <algorithm>
using namespace llvm;
-STATISTIC(NumUnrolled, "Number of loops completely unrolled");
+STATISTIC(NumCompletelyUnrolled, "Number of loops completely unrolled");
+STATISTIC(NumUnrolled, "Number of loops unrolled (completely or otherwise)");
namespace {
cl::opt<unsigned>
- UnrollThreshold("unroll-threshold", cl::init(100), cl::Hidden,
- cl::desc("The cut-off point for loop unrolling"));
+ UnrollThreshold
+ ("unroll-threshold", cl::init(100), cl::Hidden,
+ cl::desc("The cut-off point for automatic loop unrolling"));
+
+ cl::opt<unsigned>
+ UnrollCount
+ ("unroll-count", cl::init(0), cl::Hidden,
+ cl::desc("Use this unroll count for all loops, for testing purposes"));
class VISIBILITY_HIDDEN LoopUnroll : public LoopPass {
LoopInfo *LI; // The current loop information
@@ -52,7 +60,13 @@
static char ID; // Pass ID, replacement for typeid
LoopUnroll() : LoopPass((intptr_t)&ID) {}
+ /// A magic value for use with the Threshold parameter to indicate
+ /// that the loop unroll should be performed regardless of how much
+ /// code expansion would result.
+ static const unsigned NoThreshold = UINT_MAX;
+
bool runOnLoop(Loop *L, LPPassManager &LPM);
+ bool unrollLoop(Loop *L, unsigned Count, unsigned Threshold);
BasicBlock *FoldBlockIntoPredecessor(BasicBlock *BB);
/// This transformation requires natural loop information & requires that
@@ -162,43 +176,137 @@
}
bool LoopUnroll::runOnLoop(Loop *L, LPPassManager &LPM) {
- bool Changed = false;
LI = &getAnalysis<LoopInfo>();
+ // Unroll the loop.
+ if (!unrollLoop(L, UnrollCount, UnrollThreshold))
+ return false;
+
+ // Update the loop information for this loop.
+ // If we completely unrolled the loop, remove it from the parent.
+ if (L->getNumBackEdges() == 0)
+ LPM.deleteLoopFromQueue(L);
+
+ return true;
+}
+
+/// Unroll the given loop by Count, or by a heuristically-determined value
+/// if Count is zero. Unless Threshold is NoThreshold, it is a limit on
+/// code size expansion: if the unrolled loop size would exceed the
+/// threshold value, unrolling is suppressed. The return value is false
+/// if no transformations are performed.
+///
+bool LoopUnroll::unrollLoop(Loop *L, unsigned Count, unsigned Threshold) {
+ assert(L->isLCSSAForm());
+
BasicBlock *Header = L->getHeader();
BasicBlock *LatchBlock = L->getLoopLatch();
-
BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
- if (BI == 0) return Changed; // Must end in a conditional branch
- ConstantInt *TripCountC = dyn_cast_or_null<ConstantInt>(L->getTripCount());
- if (!TripCountC) return Changed; // Must have constant trip count!
+ DOUT << "Loop Unroll: F[" << Header->getParent()->getName()
+ << "] Loop %" << Header->getName() << "\n";
- // Guard against huge trip counts. This also guards against assertions in
- // APInt from the use of getZExtValue, below.
- if (TripCountC->getValue().getActiveBits() > 32)
- return Changed; // More than 2^32 iterations???
+ if (!BI || BI->isUnconditional()) {
+ // The loop-rotate pass can be helpful to avoid this in many cases.
+ DOUT << " Can't unroll; loop not terminated by a conditional branch.\n";
+ return false;
+ }
- uint64_t TripCountFull = TripCountC->getZExtValue();
- if (TripCountFull == 0)
- return Changed; // Zero iteraitons?
+ // Determine the trip count and/or trip multiple. A TripCount value of zero
+ // is used to mean an unknown trip count. The TripMultiple value is the
+ // greatest known integer that evenly divides the trip count.
+ unsigned TripCount = 0;
+ unsigned TripMultiple = 1;
+ if (Value *TripCountValue = L->getTripCount()) {
+ if (ConstantInt *TripCountC = dyn_cast<ConstantInt>(TripCountValue)) {
+ // Guard against huge trip counts. This also guards against assertions in
+ // APInt from the use of getZExtValue, below.
+ if (TripCountC->getValue().getActiveBits() <= 32) {
+ TripCount = (unsigned)TripCountC->getZExtValue();
+ }
+ } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(TripCountValue)) {
+ switch (BO->getOpcode()) {
+ case BinaryOperator::Mul:
+ if (ConstantInt *MultipleC = dyn_cast<ConstantInt>(BO->getOperand(1))) {
+ if (MultipleC->getValue().getActiveBits() <= 32) {
+ TripMultiple = (unsigned)MultipleC->getZExtValue();
+ }
+ }
+ break;
+ default: break;
+ }
+ }
+ }
+ if (TripCount != 0)
+ DOUT << " Trip Count = " << TripCount << "\n";
+ if (TripMultiple != 1)
+ DOUT << " Trip Multiple = " << TripMultiple << "\n";
- unsigned LoopSize = ApproximateLoopSize(L);
- DOUT << "Loop Unroll: F[" << Header->getParent()->getName()
- << "] Loop %" << Header->getName() << " Loop Size = "
- << LoopSize << " Trip Count = " << TripCountFull << " - ";
- uint64_t Size = (uint64_t)LoopSize*TripCountFull;
- if (Size > UnrollThreshold) {
- DOUT << "TOO LARGE: " << Size << ">" << UnrollThreshold << "\n";
- return Changed;
+ // Automatically select an unroll count.
+ if (Count == 0) {
+ // Conservative heuristic: if we know the trip count, see if we can
+ // completely unroll (subject to the threshold, checked below); otherwise
+ // don't unroll.
+ if (TripCount != 0) {
+ Count = TripCount;
+ } else {
+ return false;
+ }
}
- DOUT << "UNROLLING!\n";
- std::vector<BasicBlock*> LoopBlocks = L->getBlocks();
+ // Effectively "DCE" unrolled iterations that are beyond the tripcount
+ // and will never be executed.
+ if (TripCount != 0 && Count > TripCount)
+ Count = TripCount;
- unsigned TripCount = (unsigned)TripCountFull;
+ assert(Count > 0);
+ assert(TripMultiple > 0);
+ assert(TripCount == 0 || TripCount % TripMultiple == 0);
- BasicBlock *LoopExit = BI->getSuccessor(L->contains(BI->getSuccessor(0)));
+ // Enforce the threshold.
+ if (Threshold != NoThreshold) {
+ unsigned LoopSize = ApproximateLoopSize(L);
+ DOUT << " Loop Size = " << LoopSize << "\n";
+ uint64_t Size = (uint64_t)LoopSize*Count;
+ if (TripCount != 1 && Size > Threshold) {
+ DOUT << " TOO LARGE TO UNROLL: "
+ << Size << ">" << Threshold << "\n";
+ return false;
+ }
+ }
+
+ // Are we eliminating the loop control altogether?
+ bool CompletelyUnroll = Count == TripCount;
+
+ // If we know the trip count, we know the multiple...
+ unsigned BreakoutTrip = 0;
+ if (TripCount != 0) {
+ BreakoutTrip = TripCount % Count;
+ TripMultiple = 0;
+ } else {
+ // Figure out what multiple to use.
+ BreakoutTrip = TripMultiple =
+ (unsigned)GreatestCommonDivisor64(Count, TripMultiple);
+ }
+
+ if (CompletelyUnroll) {
+ DOUT << "COMPLETELY UNROLLING loop %" << Header->getName()
+ << " with trip count " << TripCount << "!\n";
+ } else {
+ DOUT << "UNROLLING loop %" << Header->getName()
+ << " by " << Count;
+ if (TripMultiple == 0 || BreakoutTrip != TripMultiple) {
+ DOUT << " with a breakout at trip " << BreakoutTrip;
+ } else if (TripMultiple != 1) {
+ DOUT << " with " << TripMultiple << " trips per branch";
+ }
+ DOUT << "!\n";
+ }
+
+ std::vector<BasicBlock*> LoopBlocks = L->getBlocks();
+
+ bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
+ BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
// For the first iteration of the loop, we should use the precloned values for
// PHI nodes. Insert associations now.
@@ -214,16 +322,12 @@
LastValueMap[I] = I;
}
- // Remove the exit branch from the loop
- LatchBlock->getInstList().erase(BI);
-
std::vector<BasicBlock*> Headers;
std::vector<BasicBlock*> Latches;
Headers.push_back(Header);
Latches.push_back(LatchBlock);
- assert(TripCount != 0 && "Trip count of 0 is impossible!");
- for (unsigned It = 1; It != TripCount; ++It) {
+ for (unsigned It = 1; It != Count; ++It) {
char SuffixBuffer[100];
sprintf(SuffixBuffer, ".%d", It);
@@ -277,9 +381,18 @@
// we can insert the proper branches later.
if (*BB == Header)
Headers.push_back(New);
- if (*BB == LatchBlock)
+ if (*BB == LatchBlock) {
Latches.push_back(New);
+ // Also, clear out the new latch's back edge so that it doesn't look
+ // like a new loop, so that it's amenable to being merged with adjacent
+ // blocks later on.
+ TerminatorInst *Term = New->getTerminator();
+ assert(L->contains(Term->getSuccessor(!ContinueOnTrue)));
+ assert(Term->getSuccessor(ContinueOnTrue) == LoopExit);
+ Term->setSuccessor(!ContinueOnTrue, NULL);
+ }
+
NewBlocks.push_back(New);
}
@@ -289,12 +402,11 @@
E = NewBlocks[i]->end(); I != E; ++I)
RemapInstruction(I, LastValueMap);
}
-
// The latch block exits the loop. If there are any PHI nodes in the
// successor blocks, update them to use the appropriate values computed as the
// last iteration of the loop.
- if (TripCount > 1) {
+ if (Count != 1) {
SmallPtrSet<PHINode*, 8> Users;
for (Value::use_iterator UI = LatchBlock->use_begin(),
UE = LatchBlock->use_end(); UI != UE; ++UI)
@@ -316,29 +428,55 @@
}
}
- // Now loop over the PHI nodes in the original block, setting them to their
- // incoming values.
- BasicBlock *Preheader = L->getLoopPreheader();
- for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
- PHINode *PN = OrigPHINode[i];
- PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
- Header->getInstList().erase(PN);
+ // Now, if we're doing complete unrolling, loop over the PHI nodes in the
+ // original block, setting them to their incoming values.
+ if (CompletelyUnroll) {
+ BasicBlock *Preheader = L->getLoopPreheader();
+ for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
+ PHINode *PN = OrigPHINode[i];
+ PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
+ Header->getInstList().erase(PN);
+ }
}
-
- // Insert the branches that link the different iterations together
- for (unsigned i = 0; i < Latches.size()-1; ++i) {
- new BranchInst(Headers[i+1], Latches[i]);
- if (BasicBlock *Fold = FoldBlockIntoPredecessor(Headers[i+1])) {
- std::replace(Latches.begin(), Latches.end(), Headers[i+1], Fold);
- std::replace(Headers.begin(), Headers.end(), Headers[i+1], Fold);
+
+ // Now that all the basic blocks for the unrolled iterations are in place,
+ // set up the branches to connect them.
+ for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
+ // The original branch was replicated in each unrolled iteration.
+ BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator());
+
+ // The branch destination.
+ unsigned j = (i + 1) % e;
+ BasicBlock *Dest = Headers[j];
+ bool NeedConditional = true;
+
+ // For a complete unroll, make the last iteration end with a branch
+ // to the exit block.
+ if (CompletelyUnroll && j == 0) {
+ Dest = LoopExit;
+ NeedConditional = false;
+ }
+
+ // If we know the trip count or a multiple of it, we can safely use an
+ // unconditional branch for some iterations.
+ if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
+ NeedConditional = false;
+ }
+
+ if (NeedConditional) {
+ // Update the conditional branch's successor for the following
+ // iteration.
+ Term->setSuccessor(!ContinueOnTrue, Dest);
+ } else {
+ Term->setUnconditionalDest(Dest);
+ // Merge adjacent basic blocks, if possible.
+ if (BasicBlock *Fold = FoldBlockIntoPredecessor(Dest)) {
+ std::replace(Latches.begin(), Latches.end(), Dest, Fold);
+ std::replace(Headers.begin(), Headers.end(), Dest, Fold);
+ }
}
}
- // Finally, add an unconditional branch to the block to continue into the exit
- // block.
- new BranchInst(LoopExit, Latches[Latches.size()-1]);
- FoldBlockIntoPredecessor(LoopExit);
-
// At this point, the code is well formed. We now do a quick sweep over the
// inserted code, doing constant propagation and dead code elimination as we
// go.
@@ -356,10 +494,7 @@
}
}
- // Update the loop information for this loop.
- // Remove the loop from the parent.
- LPM.deleteLoopFromQueue(L);
-
+ NumCompletelyUnrolled += CompletelyUnroll;
++NumUnrolled;
return true;
}
More information about the llvm-commits
mailing list