[llvm-commits] [llvm] r52645 - in /llvm/trunk: lib/Transforms/Utils/UnrollLoop.cpp test/Transforms/LoopUnroll/multiple-phis.ll test/Transforms/LoopUnroll/pr2253.ll

Dan Gohman gohman at apple.com
Mon Jun 23 14:29:42 PDT 2008


Author: djg
Date: Mon Jun 23 16:29:41 2008
New Revision: 52645

URL: http://llvm.org/viewvc/llvm-project?rev=52645&view=rev
Log:
Revamp the loop unroller so that it correctly updates PHI nodes when
in-loop values have out-of-loop users and the trip count is not a known
multiple of the unroll count, and make it a bit simpler overall. This
fixes PR2253.

Added:
    llvm/trunk/test/Transforms/LoopUnroll/multiple-phis.ll
    llvm/trunk/test/Transforms/LoopUnroll/pr2253.ll
Modified:
    llvm/trunk/lib/Transforms/Utils/UnrollLoop.cpp

Modified: llvm/trunk/lib/Transforms/Utils/UnrollLoop.cpp
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/lib/Transforms/Utils/UnrollLoop.cpp?rev=52645&r1=52644&r2=52645&view=diff

==============================================================================
--- llvm/trunk/lib/Transforms/Utils/UnrollLoop.cpp (original)
+++ llvm/trunk/lib/Transforms/Utils/UnrollLoop.cpp Mon Jun 23 16:29:41 2008
@@ -22,6 +22,7 @@
 #include "llvm/Transforms/Utils/UnrollLoop.h"
 #include "llvm/BasicBlock.h"
 #include "llvm/ADT/Statistic.h"
+#include "llvm/ADT/STLExtras.h"
 #include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/LoopPass.h"
 #include "llvm/Support/Debug.h"
@@ -106,13 +107,17 @@
 ///
 /// If a LoopPassManager is passed in, and the loop is fully removed, it will be
 /// removed from the LoopPassManager as well. LPM can also be NULL.
-bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI, LPPassManager* LPM) {
+bool llvm::UnrollLoop(Loop *L, unsigned Count, LoopInfo* LI,
+                      LPPassManager* LPM) {
   assert(L->isLCSSAForm());
 
   BasicBlock *Header = L->getHeader();
   BasicBlock *LatchBlock = L->getLoopLatch();
   BranchInst *BI = dyn_cast<BranchInst>(LatchBlock->getTerminator());
-  
+
+  Function *Func = Header->getParent();
+  Function::iterator BBInsertPt = next(Function::iterator(LatchBlock));
+
   if (!BI || BI->isUnconditional()) {
     // The loop-rotate pass can be helpful to avoid this in many cases.
     DOUT << "  Can't unroll; loop not terminated by a conditional branch.\n";
@@ -168,162 +173,148 @@
     DOUT << "!\n";
   }
 
+  // Make a copy of the original LoopBlocks list so we can keep referring
+  // to it while hacking on the loop.
   std::vector<BasicBlock*> LoopBlocks = L->getBlocks();
 
-  bool ContinueOnTrue = L->contains(BI->getSuccessor(0));
+  bool ContinueOnTrue = BI->getSuccessor(0) == Header;
   BasicBlock *LoopExit = BI->getSuccessor(ContinueOnTrue);
 
   // For the first iteration of the loop, we should use the precloned values for
   // PHI nodes.  Insert associations now.
   typedef DenseMap<const Value*, Value*> ValueMapTy;
   ValueMapTy LastValueMap;
-  std::vector<PHINode*> OrigPHINode;
   for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
     PHINode *PN = cast<PHINode>(I);
-    OrigPHINode.push_back(PN);
     if (Instruction *I = 
                 dyn_cast<Instruction>(PN->getIncomingValueForBlock(LatchBlock)))
       if (L->contains(I->getParent()))
         LastValueMap[I] = I;
   }
 
+  // Keep track of all the headers and latches that we create. These are
+  // needed by the logic that inserts the branches to connect all the
+  // new blocks.
   std::vector<BasicBlock*> Headers;
   std::vector<BasicBlock*> Latches;
+  Headers.reserve(Count);
+  Latches.reserve(Count);
   Headers.push_back(Header);
   Latches.push_back(LatchBlock);
 
+  // Iterate through all but the first iterations, cloning blocks from
+  // the first iteration to populate the subsequent iterations.
   for (unsigned It = 1; It != Count; ++It) {
     char SuffixBuffer[100];
     sprintf(SuffixBuffer, ".%d", It);
     
     std::vector<BasicBlock*> NewBlocks;
+    NewBlocks.reserve(LoopBlocks.size());
     
-    for (std::vector<BasicBlock*>::iterator BB = LoopBlocks.begin(),
-         E = LoopBlocks.end(); BB != E; ++BB) {
+    // Iterate through all the blocks in the original loop.
+    for (std::vector<BasicBlock*>::const_iterator BBI = LoopBlocks.begin(),
+         E = LoopBlocks.end(); BBI != E; ++BBI) {
+      bool SuppressExitEdges = false;
+      BasicBlock *BB = *BBI;
       ValueMapTy ValueMap;
-      BasicBlock *New = CloneBasicBlock(*BB, ValueMap, SuffixBuffer);
-      Header->getParent()->getBasicBlockList().push_back(New);
+      BasicBlock *New = CloneBasicBlock(BB, ValueMap, SuffixBuffer);
+      NewBlocks.push_back(New);
+      Func->getBasicBlockList().insert(BBInsertPt, New);
+      L->addBasicBlockToLoop(New, LI->getBase());
 
-      // Loop over all of the PHI nodes in the block, changing them to use the
-      // incoming values from the previous block.
-      if (*BB == Header)
-        for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
-          PHINode *NewPHI = cast<PHINode>(ValueMap[OrigPHINode[i]]);
+      // Special handling for the loop header block.
+      if (BB == Header) {
+        // Keep track of new headers as we create them, so that we can insert
+        // the proper branches later.
+        Headers[It] = New;
+
+        // Loop over all of the PHI nodes in the block, changing them to use
+        // the incoming values from the previous block.
+        for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
+          PHINode *NewPHI = cast<PHINode>(ValueMap[I]);
           Value *InVal = NewPHI->getIncomingValueForBlock(LatchBlock);
           if (Instruction *InValI = dyn_cast<Instruction>(InVal))
             if (It > 1 && L->contains(InValI->getParent()))
               InVal = LastValueMap[InValI];
-          ValueMap[OrigPHINode[i]] = InVal;
+          ValueMap[I] = InVal;
           New->getInstList().erase(NewPHI);
         }
+      }
+
+      // Special handling for the loop latch block.
+      if (BB == LatchBlock) {
+        // Keep track of new latches as we create them, so that we can insert
+        // the proper branches later.
+        Latches[It] = New;
+
+        // If knowledge of the trip count and/or multiple will allow us
+        // to emit unconditional branches in some of the new latch blocks,
+        // those blocks shouldn't be referenced by PHIs that reference
+        // the original latch.
+        unsigned NextIt = (It + 1) % Count;
+        SuppressExitEdges =
+          NextIt != BreakoutTrip &&
+          (TripMultiple == 0 || NextIt % TripMultiple != 0);
+      }
 
       // Update our running map of newest clones
-      LastValueMap[*BB] = New;
+      LastValueMap[BB] = New;
       for (ValueMapTy::iterator VI = ValueMap.begin(), VE = ValueMap.end();
            VI != VE; ++VI)
         LastValueMap[VI->first] = VI->second;
 
-      L->addBasicBlockToLoop(New, LI->getBase());
-
-      // Add phi entries for newly created values to all exit blocks except
-      // the successor of the latch block.  The successor of the exit block will
-      // be updated specially after unrolling all the way.
-      if (*BB != LatchBlock)
-        for (Value::use_iterator UI = (*BB)->use_begin(), UE = (*BB)->use_end();
-             UI != UE;) {
-          Instruction *UseInst = cast<Instruction>(*UI);
-          ++UI;
-          if (isa<PHINode>(UseInst) && !L->contains(UseInst->getParent())) {
-            PHINode *phi = cast<PHINode>(UseInst);
-            Value *Incoming = phi->getIncomingValueForBlock(*BB);
-            phi->addIncoming(Incoming, New);
-          }
+      // Add incoming values to phi nodes that reference this block. The last
+      // latch block may need to be referenced by the first header, and any
+      // block with an exit edge may be referenced from outside the loop.
+      for (Value::use_iterator UI = BB->use_begin(), UE = BB->use_end();
+           UI != UE; ) {
+        PHINode *PN = dyn_cast<PHINode>(*UI++);
+        if (PN &&
+            ((BB == LatchBlock && It == Count - 1 && !CompletelyUnroll) ||
+             (!SuppressExitEdges && !L->contains(PN->getParent())))) {
+          Value *InVal = PN->getIncomingValueForBlock(BB);
+          // If this value was defined in the loop, take the value defined
+          // by the last iteration of the loop.
+          ValueMapTy::iterator VI = LastValueMap.find(InVal);
+          if (VI != LastValueMap.end())
+            InVal = VI->second;
+          PN->addIncoming(InVal, New);
         }
-
-      // Keep track of new headers and latches as we create them, so that
-      // we can insert the proper branches later.
-      if (*BB == Header)
-        Headers.push_back(New);
-      if (*BB == LatchBlock) {
-        Latches.push_back(New);
-
-        // Also, clear out the new latch's back edge so that it doesn't look
-        // like a new loop, so that it's amenable to being merged with adjacent
-        // blocks later on.
-        TerminatorInst *Term = New->getTerminator();
-        assert(L->contains(Term->getSuccessor(!ContinueOnTrue)));
-        assert(Term->getSuccessor(ContinueOnTrue) == LoopExit);
-        Term->setSuccessor(!ContinueOnTrue, NULL);
       }
-
-      NewBlocks.push_back(New);
     }
     
     // Remap all instructions in the most recent iteration
-    for (unsigned i = 0; i < NewBlocks.size(); ++i)
+    for (unsigned i = 0, e = NewBlocks.size(); i != e; ++i)
       for (BasicBlock::iterator I = NewBlocks[i]->begin(),
            E = NewBlocks[i]->end(); I != E; ++I)
         RemapInstruction(I, LastValueMap);
   }
-  
-  // The latch block exits the loop.  If there are any PHI nodes in the
-  // successor blocks, update them to use the appropriate values computed as the
-  // last iteration of the loop.
-  if (Count != 1) {
-    SmallPtrSet<PHINode*, 8> Users;
-    for (Value::use_iterator UI = LatchBlock->use_begin(),
-         UE = LatchBlock->use_end(); UI != UE; ++UI)
-      if (PHINode *phi = dyn_cast<PHINode>(*UI))
-        Users.insert(phi);
-    
-    BasicBlock *LastIterationBB = cast<BasicBlock>(LastValueMap[LatchBlock]);
-    for (SmallPtrSet<PHINode*,8>::iterator SI = Users.begin(), SE = Users.end();
-         SI != SE; ++SI) {
-      PHINode *PN = *SI;
-      Value *InVal = PN->removeIncomingValue(LatchBlock, false);
-      // If this value was defined in the loop, take the value defined by the
-      // last iteration of the loop.
-      if (Instruction *InValI = dyn_cast<Instruction>(InVal)) {
-        if (L->contains(InValI->getParent()))
-          InVal = LastValueMap[InVal];
-      }
-      PN->addIncoming(InVal, LastIterationBB);
-    }
-  }
-
-  // Now, if we're doing complete unrolling, loop over the PHI nodes in the
-  // original block, setting them to their incoming values.
-  if (CompletelyUnroll) {
-    BasicBlock *Preheader = L->getLoopPreheader();
-    for (unsigned i = 0, e = OrigPHINode.size(); i != e; ++i) {
-      PHINode *PN = OrigPHINode[i];
-      PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
-      Header->getInstList().erase(PN);
-    }
-  }
 
   // Now that all the basic blocks for the unrolled iterations are in place,
   // set up the branches to connect them.
-  for (unsigned i = 0, e = Latches.size(); i != e; ++i) {
+  for (unsigned It = 0; It != Count; ++It) {
     // The original branch was replicated in each unrolled iteration.
-    BranchInst *Term = cast<BranchInst>(Latches[i]->getTerminator());
+    BranchInst *Term = cast<BranchInst>(Latches[It]->getTerminator());
 
     // The branch destination.
-    unsigned j = (i + 1) % e;
-    BasicBlock *Dest = Headers[j];
+    unsigned NextIt = (It + 1) % Count;
+    BasicBlock *Dest = Headers[NextIt];
     bool NeedConditional = true;
+    bool HasExit = true;
 
-    // For a complete unroll, make the last iteration end with a branch
-    // to the exit block.
-    if (CompletelyUnroll && j == 0) {
+    // For a complete unroll, make the last iteration end with an
+    // unconditional branch to the exit block.
+    if (CompletelyUnroll && NextIt == 0) {
       Dest = LoopExit;
       NeedConditional = false;
     }
 
     // If we know the trip count or a multiple of it, we can safely use an
     // unconditional branch for some iterations.
-    if (j != BreakoutTrip && (TripMultiple == 0 || j % TripMultiple != 0)) {
+    if (NextIt != BreakoutTrip &&
+        (TripMultiple == 0 || NextIt % TripMultiple != 0)) {
       NeedConditional = false;
+      HasExit = false;
     }
 
     if (NeedConditional) {
@@ -338,24 +329,50 @@
         std::replace(Headers.begin(), Headers.end(), Dest, Fold);
       }
     }
+
+    // Special handling for the first iteration. If the first latch is
+    // now unconditionally branching to the second header, then it is
+    // no longer an exit node. Delete PHI references to it both from
+    // the first header and from outsie the loop.
+    if (It == 0)
+      for (Value::use_iterator UI = LatchBlock->use_begin(),
+           UE = LatchBlock->use_end(); UI != UE; ) {
+        PHINode *PN = dyn_cast<PHINode>(*UI++);
+        if (PN && (PN->getParent() == Header ? Count > 1 : !HasExit))
+          PN->removeIncomingValue(LatchBlock);
+      }
   }
   
-  // At this point, the code is well formed.  We now do a quick sweep over the
-  // inserted code, doing constant propagation and dead code elimination as we
-  // go.
-  const std::vector<BasicBlock*> &NewLoopBlocks = L->getBlocks();
-  for (std::vector<BasicBlock*>::const_iterator BB = NewLoopBlocks.begin(),
-       BBE = NewLoopBlocks.end(); BB != BBE; ++BB)
-    for (BasicBlock::iterator I = (*BB)->begin(), E = (*BB)->end(); I != E; ) {
+  // At this point, unrolling is complete and the code is well formed. 
+  // Now, do some simplifications.
+
+  // If we're doing complete unrolling, loop over the PHI nodes in the
+  // original block, setting them to their incoming values.
+  if (CompletelyUnroll) {
+    BasicBlock *Preheader = L->getLoopPreheader();
+    for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ) {
+      PHINode *PN = cast<PHINode>(I++);
+      PN->replaceAllUsesWith(PN->getIncomingValueForBlock(Preheader));
+      Header->getInstList().erase(PN);
+    }
+  }
+
+  // We now do a quick sweep over the inserted code, doing constant
+  // propagation and dead code elimination as we go.
+  for (Loop::block_iterator BI = L->block_begin(), BBE = L->block_end();
+       BI != BBE; ++BI) {
+    BasicBlock *BB = *BI;
+    for (BasicBlock::iterator I = BB->begin(), E = BB->end(); I != E; ) {
       Instruction *Inst = I++;
 
       if (isInstructionTriviallyDead(Inst))
-        (*BB)->getInstList().erase(Inst);
+        BB->getInstList().erase(Inst);
       else if (Constant *C = ConstantFoldInstruction(Inst)) {
         Inst->replaceAllUsesWith(C);
-        (*BB)->getInstList().erase(Inst);
+        BB->getInstList().erase(Inst);
       }
     }
+  }
 
   NumCompletelyUnrolled += CompletelyUnroll;
   ++NumUnrolled;

Added: llvm/trunk/test/Transforms/LoopUnroll/multiple-phis.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopUnroll/multiple-phis.ll?rev=52645&view=auto

==============================================================================
--- llvm/trunk/test/Transforms/LoopUnroll/multiple-phis.ll (added)
+++ llvm/trunk/test/Transforms/LoopUnroll/multiple-phis.ll Mon Jun 23 16:29:41 2008
@@ -0,0 +1,51 @@
+; RUN: llvm-as < %s | opt -loop-unroll -unroll-count 6 -unroll-threshold 300 | llvm-dis > %t
+; RUN: grep {br label \%bbe} %t | count 12
+; RUN: grep {br i1 \%z} %t | count 3
+; RUN: grep {br i1 \%q} %t | count 6
+; RUN: grep call %t | count 12
+; RUN: grep urem %t | count 6
+; RUN: grep store %t | count 6
+; RUN: grep phi %t | count 11
+; RUN: grep {lcssa = phi} %t | count 2
+
+; This testcase uses
+;  - an unknown tripcount, but a known trip multiple of 2.
+;  - an unroll count of 6, so we should get 3 conditional branches
+;    in the loop.
+;  - values defined inside the loop and used outside, by phis that
+;    also use values defined elsewhere outside the loop.
+;  - a phi inside the loop that only uses values defined
+;    inside the loop and is only used inside the loop.
+
+declare i32 @foo()
+declare i32 @bar()
+
+define i32 @fib(i32 %n, i1 %a, i32* %p) nounwind {
+entry:
+        %n2 = mul i32 %n, 2
+        br i1 %a, label %bb, label %return
+
+bb: ; loop header block
+        %t0 = phi i32 [ 0, %entry ], [ %t1, %bbe ]
+        %td = urem i32 %t0, 7
+        %q = trunc i32 %td to i1
+        br i1 %q, label %bbt, label %bbf
+bbt:
+        %bbtv = call i32 @foo()
+        br label %bbe
+bbf:
+        %bbfv = call i32 @bar()
+        br label %bbe
+bbe: ; loop latch block
+        %bbpv = phi i32 [ %bbtv, %bbt ], [ %bbfv, %bbf ]
+        store i32 %bbpv, i32* %p
+        %t1 = add i32 %t0, 1
+        %z = icmp ne i32 %t1, %n2
+        br i1 %z, label %bb, label %return
+
+return:
+        %f = phi i32 [ -2, %entry ], [ %t0, %bbe ]
+        %g = phi i32 [ -3, %entry ], [ %t1, %bbe ]
+        %h = mul i32 %f, %g
+        ret i32 %h
+}

Added: llvm/trunk/test/Transforms/LoopUnroll/pr2253.ll
URL: http://llvm.org/viewvc/llvm-project/llvm/trunk/test/Transforms/LoopUnroll/pr2253.ll?rev=52645&view=auto

==============================================================================
--- llvm/trunk/test/Transforms/LoopUnroll/pr2253.ll (added)
+++ llvm/trunk/test/Transforms/LoopUnroll/pr2253.ll Mon Jun 23 16:29:41 2008
@@ -0,0 +1,21 @@
+; RUN: llvm-as < %s | opt -loop-unroll -unroll-count 2 | llvm-dis | grep add | count 2
+; PR2253
+
+; There's a use outside the loop, and the PHI needs an incoming edge for
+; each unrolled iteration, since the trip count is unknown and any iteration
+; could exit.
+
+define i32 @fib(i32 %n) nounwind {
+entry:
+        br i1 false, label %bb, label %return
+
+bb:
+        %t0 = phi i32 [ 0, %entry ], [ %t1, %bb ]
+        %t1 = add i32 %t0, 1
+        %c = icmp ne i32 %t0, %n
+        br i1 %c, label %bb, label %return
+
+return:
+        %f2.0.lcssa = phi i32 [ -1, %entry ], [ %t0, %bb ]
+        ret i32 %f2.0.lcssa
+}





More information about the llvm-commits mailing list