[llvm-commits] CVS: llvm/lib/Transforms/Utils/CodeExtractor.cpp

Chris Lattner lattner at cs.uiuc.edu
Sun Mar 14 16:38:06 PST 2004


Changes in directory llvm/lib/Transforms/Utils:

CodeExtractor.cpp updated: 1.5 -> 1.6

---
Log message:

Simplify the code a bit by making the collection of basic blocks to extract
a member of the class.  While we're at it, turn the collection into a set
instead of a vector to improve efficiency and make queries simpler.


---
Diffs of the changes:  (+39 -57)

Index: llvm/lib/Transforms/Utils/CodeExtractor.cpp
diff -u llvm/lib/Transforms/Utils/CodeExtractor.cpp:1.5 llvm/lib/Transforms/Utils/CodeExtractor.cpp:1.6
--- llvm/lib/Transforms/Utils/CodeExtractor.cpp:1.5	Sat Mar 13 22:01:47 2004
+++ llvm/lib/Transforms/Utils/CodeExtractor.cpp	Sun Mar 14 16:34:55 2004
@@ -26,16 +26,11 @@
 #include "Support/Debug.h"
 #include "Support/StringExtras.h"
 #include <algorithm>
-#include <map>
-#include <vector>
+#include <set>
 using namespace llvm;
 
 namespace {
 
-  inline bool contains(const std::vector<BasicBlock*> &V, const BasicBlock *BB){
-    return std::find(V.begin(), V.end(), BB) != V.end();
-  }
-
   /// getFunctionArg - Return a pointer to F's ARGNOth argument.
   ///
   Argument *getFunctionArg(Function *F, unsigned argno) {
@@ -49,19 +44,16 @@
     typedef std::vector<std::pair<unsigned, unsigned> > PhiValChangesTy;
     typedef std::map<PHINode*, PhiValChangesTy> PhiVal2ArgTy;
     PhiVal2ArgTy PhiVal2Arg;
-
+    std::set<BasicBlock*> BlocksToExtract;
   public:
     Function *ExtractCodeRegion(const std::vector<BasicBlock*> &code);
 
   private:
-    void findInputsOutputs(const std::vector<BasicBlock*> &code,
-                           Values &inputs,
-                           Values &outputs,
+    void findInputsOutputs(Values &inputs, Values &outputs,
                            BasicBlock *newHeader,
                            BasicBlock *newRootNode);
 
     void processPhiNodeInputs(PHINode *Phi,
-                              const std::vector<BasicBlock*> &code,
                               Values &inputs,
                               BasicBlock *newHeader,
                               BasicBlock *newRootNode);
@@ -71,15 +63,12 @@
     Function *constructFunction(const Values &inputs,
                                 const Values &outputs,
                                 BasicBlock *newRootNode, BasicBlock *newHeader,
-                                const std::vector<BasicBlock*> &code,
                                 Function *oldFunction, Module *M);
 
-    void moveCodeToFunction(const std::vector<BasicBlock*> &code,
-                            Function *newFunction);
+    void moveCodeToFunction(Function *newFunction);
 
     void emitCallAndSwitchStatement(Function *newFunction,
                                     BasicBlock *newHeader,
-                                    const std::vector<BasicBlock*> &code,
                                     Values &inputs,
                                     Values &outputs);
 
@@ -87,7 +76,6 @@
 }
 
 void CodeExtractor::processPhiNodeInputs(PHINode *Phi,
-                                         const std::vector<BasicBlock*> &code,
                                          Values &inputs,
                                          BasicBlock *codeReplacer,
                                          BasicBlock *newFuncRoot)
@@ -102,11 +90,11 @@
   for (unsigned i = 0, e = Phi->getNumIncomingValues(); i != e; ++i) {
     Value *phiVal = Phi->getIncomingValue(i);
     if (Instruction *Inst = dyn_cast<Instruction>(phiVal)) {
-      if (contains(code, Inst->getParent())) {
-        if (!contains(code, Phi->getIncomingBlock(i)))
+      if (BlocksToExtract.count(Inst->getParent())) {
+        if (!BlocksToExtract.count(Phi->getIncomingBlock(i)))
           IValEBB.push_back(i);
       } else {
-        if (contains(code, Phi->getIncomingBlock(i)))
+        if (BlocksToExtract.count(Phi->getIncomingBlock(i)))
           EValIBB.push_back(i);
         else
           EValEBB.push_back(i);
@@ -114,11 +102,11 @@
     } else if (Constant *Const = dyn_cast<Constant>(phiVal)) {
       // Constants are internal, but considered `external' if they are coming
       // from an external block.
-      if (!contains(code, Phi->getIncomingBlock(i)))
+      if (!BlocksToExtract.count(Phi->getIncomingBlock(i)))
         EValEBB.push_back(i);
     } else if (Argument *Arg = dyn_cast<Argument>(phiVal)) {
       // arguments are external
-      if (contains(code, Phi->getIncomingBlock(i)))
+      if (BlocksToExtract.count(Phi->getIncomingBlock(i)))
         EValIBB.push_back(i);
       else
         EValEBB.push_back(i);
@@ -184,14 +172,13 @@
 }
 
 
-void CodeExtractor::findInputsOutputs(const std::vector<BasicBlock*> &code,
-                                      Values &inputs,
+void CodeExtractor::findInputsOutputs(Values &inputs,
                                       Values &outputs,
                                       BasicBlock *newHeader,
                                       BasicBlock *newRootNode)
 {
-  for (std::vector<BasicBlock*>::const_iterator ci = code.begin(), 
-       ce = code.end(); ci != ce; ++ci) {
+  for (std::set<BasicBlock*>::const_iterator ci = BlocksToExtract.begin(), 
+       ce = BlocksToExtract.end(); ci != ce; ++ci) {
     BasicBlock *BB = *ci;
     for (BasicBlock::iterator BBi = BB->begin(), BBe = BB->end();
          BBi != BBe; ++BBi) {
@@ -200,7 +187,7 @@
       if (Instruction *I = dyn_cast<Instruction>(&*BBi)) {
         // If it's a phi node
         if (PHINode *Phi = dyn_cast<PHINode>(I)) {
-          processPhiNodeInputs(Phi, code, inputs, newHeader, newRootNode);
+          processPhiNodeInputs(Phi, inputs, newHeader, newRootNode);
         } else {
           // All other instructions go through the generic input finder
           // Loop over the operands of each instruction (inputs)
@@ -208,7 +195,7 @@
                op != opE; ++op) {
             if (Instruction *opI = dyn_cast<Instruction>(op->get())) {
               // Check if definition of this operand is within the loop
-              if (!contains(code, opI->getParent())) {
+              if (!BlocksToExtract.count(opI->getParent())) {
                 // add this operand to the inputs
                 inputs.push_back(opI);
               }
@@ -220,7 +207,7 @@
         for (Value::use_iterator use = I->use_begin(), useE = I->use_end();
              use != useE; ++use) {
           if (Instruction* inst = dyn_cast<Instruction>(*use)) {
-            if (!contains(code, inst->getParent())) {
+            if (!BlocksToExtract.count(inst->getParent())) {
               // add this op to the outputs
               outputs.push_back(I);
             }
@@ -276,11 +263,10 @@
                                            const Values &outputs,
                                            BasicBlock *newRootNode,
                                            BasicBlock *newHeader,
-                                           const std::vector<BasicBlock*> &code,
                                            Function *oldFunction, Module *M) {
   DEBUG(std::cerr << "inputs: " << inputs.size() << "\n");
   DEBUG(std::cerr << "outputs: " << outputs.size() << "\n");
-  BasicBlock *header = code[0];
+  BasicBlock *header = *BlocksToExtract.begin();
 
   // This function returns unsigned, outputs will go back by reference.
   Type *retTy = Type::UShortTy;
@@ -327,7 +313,7 @@
     for (std::vector<User*>::iterator use = Users.begin(), useE = Users.end();
          use != useE; ++use)
       if (Instruction* inst = dyn_cast<Instruction>(*use))
-        if (contains(code, inst->getParent()))
+        if (BlocksToExtract.count(inst->getParent()))
           inst->replaceUsesOfWith(inputs[i], getFunctionArg(newFunction, i));
   }
 
@@ -339,7 +325,7 @@
        i != e; ++i) {
     if (BranchInst *inst = dyn_cast<BranchInst>(*i)) {
       BasicBlock *BB = inst->getParent();
-      if (!contains(code, BB) && BB->getParent() == oldFunction) {
+      if (!BlocksToExtract.count(BB) && BB->getParent() == oldFunction) {
         // The BasicBlock which contains the branch is not in the region
         // modify the branch target to a new block
         inst->replaceUsesOfWith(header, newHeader);
@@ -350,29 +336,25 @@
   return newFunction;
 }
 
-void CodeExtractor::moveCodeToFunction(const std::vector<BasicBlock*> &code,
-                                       Function *newFunction)
+void CodeExtractor::moveCodeToFunction(Function *newFunction)
 {
-  Function *oldFunc = code[0]->getParent();
+  Function *oldFunc = (*BlocksToExtract.begin())->getParent();
   Function::BasicBlockListType &oldBlocks = oldFunc->getBasicBlockList();
     Function::BasicBlockListType &newBlocks = newFunction->getBasicBlockList();
 
-  for (std::vector<BasicBlock*>::const_iterator i = code.begin(), e =code.end();
-       i != e; ++i) {
-    BasicBlock *BB = *i;
-
+  for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(),
+         e = BlocksToExtract.end(); i != e; ++i) {
     // Delete the basic block from the old function, and the list of blocks
-    oldBlocks.remove(BB);
+    oldBlocks.remove(*i);
 
     // Insert this basic block into the new function
-    newBlocks.push_back(BB);
+    newBlocks.push_back(*i);
   }
 }
 
 void
 CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
                                           BasicBlock *codeReplacer,
-                                          const std::vector<BasicBlock*> &code,
                                           Values &inputs,
                                           Values &outputs)
 {
@@ -399,7 +381,7 @@
       for (std::vector<User*>::iterator use = Users.begin(), useE =Users.end();
            use != useE; ++use) {
         if (Instruction* inst = dyn_cast<Instruction>(*use)) {
-          if (!contains(code, inst->getParent())) {
+          if (!BlocksToExtract.count(inst->getParent())) {
             inst->replaceUsesOfWith(*i, load);
           }
         }
@@ -425,8 +407,8 @@
   // Since there may be multiple exits from the original region, make the new
   // function return an unsigned, switch on that number
   unsigned switchVal = 0;
-  for (std::vector<BasicBlock*>::const_iterator i =code.begin(), e = code.end();
-       i != e; ++i) {
+  for (std::set<BasicBlock*>::const_iterator i = BlocksToExtract.begin(),
+         e = BlocksToExtract.end(); i != e; ++i) {
     BasicBlock *BB = *i;
 
     // rewrite the terminator of the original BasicBlock
@@ -436,16 +418,14 @@
       // Restore values just before we exit
       // FIXME: Use a GetElementPtr to bunch the outputs in a struct
       for (unsigned outIdx = 0, outE = outputs.size(); outIdx != outE; ++outIdx)
-      {
         new StoreInst(outputs[outIdx],
                       getFunctionArg(newFunction, outIdx),
                       brInst);
-      }
 
       // Rewrite branches into exits which return a value based on which
       // exit we take from this function
       if (brInst->isUnconditional()) {
-        if (!contains(code, brInst->getSuccessor(0))) {
+        if (!BlocksToExtract.count(brInst->getSuccessor(0))) {
           ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal);
           ReturnInst *newRet = new ReturnInst(brVal);
           // add a new target to the switch
@@ -461,7 +441,7 @@
         // to two new blocks, each of which returns a different code.
         for (unsigned idx = 0; idx < 2; ++idx) {
           BasicBlock *oldTarget = brInst->getSuccessor(idx);
-          if (!contains(code, oldTarget)) {
+          if (!BlocksToExtract.count(oldTarget)) {
             // add a new basic block which returns the appropriate value
             BasicBlock *newTarget = new BasicBlock("newTarget", newFunction);
             ConstantUInt *brVal = ConstantUInt::get(Type::UShortTy, switchVal);
@@ -475,13 +455,15 @@
           }
         }
       }
+    } else if (SwitchInst *swTerm = dyn_cast<SwitchInst>(term)) {
+
+      assert(0 && "Cannot handle switch instructions just yet.");
+
     } else if (ReturnInst *retTerm = dyn_cast<ReturnInst>(term)) {
       assert(0 && "Cannot handle return instructions just yet.");
       // FIXME: what if the terminator is a return!??!
       // Need to rewrite: add new basic block, move the return there
       // treat the original as an unconditional branch to that basicblock
-    } else if (SwitchInst *swTerm = dyn_cast<SwitchInst>(term)) {
-      assert(0 && "Cannot handle switch instructions just yet.");
     } else if (InvokeInst *invInst = dyn_cast<InvokeInst>(term)) {
       assert(0 && "Cannot handle invoke instructions just yet.");
     } else {
@@ -514,7 +496,8 @@
   //  * Add allocas for defs, pass as args by reference
   //  * Pass in uses as args
   // 3) Move code region, add call instr to func
-  // 
+  //
+  BlocksToExtract.insert(code.begin(), code.end());
 
   Values inputs, outputs;
 
@@ -548,19 +531,18 @@
   // blocks moving to a new function.
   // SOLUTION: move Phi nodes out of the loop header into the codeReplacer, pass
   // the values as parameters to the function
-  findInputsOutputs(code, inputs, outputs, codeReplacer, newFuncRoot);
+  findInputsOutputs(inputs, outputs, codeReplacer, newFuncRoot);
 
   // Step 2: Construct new function based on inputs/outputs,
   // Add allocas for all defs
   Function *newFunction = constructFunction(inputs, outputs, newFuncRoot, 
-                                            codeReplacer, code, 
-                                            oldFunction, module);
+                                            codeReplacer, oldFunction, module);
 
   rewritePhiNodes(newFunction, newFuncRoot);
 
-  emitCallAndSwitchStatement(newFunction, codeReplacer, code, inputs, outputs);
+  emitCallAndSwitchStatement(newFunction, codeReplacer, inputs, outputs);
 
-  moveCodeToFunction(code, newFunction);
+  moveCodeToFunction(newFunction);
 
   DEBUG(if (verifyFunction(*newFunction)) abort());
   return newFunction;





More information about the llvm-commits mailing list