[llvm-commits] [parallel] CVS: llvm/lib/Transforms/Parallel/LowerParaBr.cpp

Misha Brukman brukman at cs.uiuc.edu
Mon May 17 19:25:01 PDT 2004


Changes in directory llvm/lib/Transforms/Parallel:

LowerParaBr.cpp updated: 1.1.2.2 -> 1.1.2.3

---
Log message:

Create a unified join block that contains all the join calls for the region.


---
Diffs of the changes:  (+75 -8)

Index: llvm/lib/Transforms/Parallel/LowerParaBr.cpp
diff -u llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.2 llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.3
--- llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.2	Fri May  7 16:37:09 2004
+++ llvm/lib/Transforms/Parallel/LowerParaBr.cpp	Mon May 17 19:24:12 2004
@@ -16,10 +16,11 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/BasicBlock.h"
 #include "llvm/DerivedTypes.h"
-#include "llvm/Module.h"
 #include "llvm/iOther.h"
 #include "llvm/iTerminators.h"
+#include "llvm/Module.h"
 #include "llvm/Pass.h"
 #include "llvm/Type.h"
 #include "llvm/Analysis/ParallelInfo.h"
@@ -44,6 +45,7 @@
   private:
     void straightenSequence(ParallelSeq *PS);
     Function* getJoinIntrinsic(Module *M);    
+    Function* getFuncThreadJoin(Module &M);
   };
 
   RegisterOpt<LowerParaBr> X("lowerpbr", "Lower pbr to sequential code");
@@ -67,10 +69,21 @@
   return Changed;
 }
 
-static bool contains(std::vector<BasicBlock*> haystack, BasicBlock *needle) {
+static inline bool
+contains(const std::vector<BasicBlock*> &haystack, const BasicBlock *needle) {
   return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
 }
 
+static inline bool
+containsJoin(const std::vector<CallInst*> &haystack, const CallInst *needle) {  
+  for (std::vector<CallInst*>::const_iterator h = haystack.begin(), 
+         he = haystack.end(); h != he; ++h)
+    if ((*h)->getOperand(1) == needle->getOperand(1)) 
+      return true;
+
+  return false;
+}
+
 /// straightenSequence - Recursively process parallel sequences
 ///
 void LowerParaBr::straightenSequence(ParallelSeq *PS) {
@@ -80,15 +93,16 @@
   {
     PrevPR = PR;
     PR = *i;
+#if 0
     for (ParallelRegion::seqiterator ci = PR->seqbegin(), ce = PR->seqend(); 
          ci != ce; ++ci) 
       straightenSequence(*ci);
-
-    // Stitch previous region to the current one by all branches to the join
-    // block instead branch to the first block of the second region
+#endif
     if (PrevPR) {
+      // Stitch previous region to the current one by all branches to the join
+      // block instead branch to the first block of the second region
       std::vector<BasicBlock*> &PrevJoinBlocks = PrevPR->getJoinBlocks(),
-        &PRBlocks = PR->getBlocks();
+        &PRJoinBlocks = PR->getJoinBlocks(), &PRBlocks = PR->getBlocks();
       BasicBlock *PRFirstBB = PRBlocks[0];
       for (std::vector<BasicBlock*>::iterator j = PrevJoinBlocks.begin(),
              je = PrevJoinBlocks.end(); j != je; ++j) {
@@ -96,9 +110,55 @@
         for (std::vector<User*>::iterator u = JUsers.begin(), ue = JUsers.end();
              u != ue; ++u)
           if (Instruction *I = dyn_cast<Instruction>(*u))
-            if (contains(PR->getBlocks(), I->getParent()))
+            if (contains(PrevPR->getBlocks(), I->getParent()))
               I->replaceUsesOfWith(*j, PRFirstBB);
       }
+
+      // Make a new join block that will house all the (coallesced) joins from
+      // both regions
+      BasicBlock *SumJoins = new BasicBlock("allJoins", PRFirstBB->getParent());
+
+      // Find the unique set of calls to the llvm.join intrinsic
+      std::vector<CallInst*> AllJoins;
+      for (std::vector<BasicBlock*>::iterator i = PrevJoinBlocks.begin(),
+             e = PrevJoinBlocks.end(); i != e; ++i) {
+        BasicBlock *BB = *i;
+        for (BasicBlock::iterator j = BB->begin(), je = BB->end(); j!=je; ++j)
+          if (CallInst *CI = dyn_cast<CallInst>(j)) 
+            if (CI->getCalledFunction()->getName() == "__llvm_thread_join")
+              if (!containsJoin(AllJoins, CI))
+                AllJoins.push_back(CI);
+      }
+      for (std::vector<BasicBlock*>::iterator i = PRJoinBlocks.begin(),
+             e = PRJoinBlocks.end(); i != e; ++i) {
+        BasicBlock *BB = *i;
+        for (BasicBlock::iterator j = BB->begin(), je = BB->end(); j!=je; ++j)
+          if (CallInst *CI = dyn_cast<CallInst>(j)) 
+            if (CI->getCalledFunction()->getName() == "__llvm_thread_join")
+              if (!containsJoin(AllJoins, CI))
+                AllJoins.push_back(CI);
+      }
+
+      // Add all unique joins to the coallesced joins block
+      for (std::vector<CallInst*>::iterator c = AllJoins.begin(), 
+             ce = AllJoins.end(); c != ce; ++c)
+        SumJoins->getInstList().push_back((*c)->clone());
+
+      // Terminate the Joins block with a branch to the post-join code
+      Instruction *TI = PRJoinBlocks[0]->getTerminator();
+      if (TI)
+        SumJoins->getInstList().push_back(TI->clone());
+
+      // Branch to the all-joins block from the main code
+      for (std::vector<BasicBlock*>::iterator j = PRJoinBlocks.begin(),
+             je = PRJoinBlocks.end(); j != je; ++j) {
+        std::vector<User*> JUsers((*j)->use_begin(), (*j)->use_end());
+        for (std::vector<User*>::iterator u = JUsers.begin(), ue = JUsers.end();
+             u != ue; ++u)
+          if (Instruction *I = dyn_cast<Instruction>(*u))
+            if (contains(PR->getBlocks(), I->getParent()))
+              I->replaceUsesOfWith(*j, SumJoins);
+      }
     }
   }
 
@@ -115,7 +175,6 @@
           
   // Remove pbr
   ParaBrInst *pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
-  //assert(pbr && "Terminator of ParaSeq header is not a pbr");
   if (pbr) {
     new BranchInst(pbr->getSuccessor(0), pbr);
     pbr->getParent()->getInstList().erase(pbr);
@@ -125,4 +184,12 @@
 Function* LowerParaBr::getJoinIntrinsic(Module *M) {
   return M->getOrInsertFunction("llvm.join", Type::VoidTy,
                                 PointerType::get(Type::SByteTy), 0);
+}
+
+/// getFuncThreadJoin -
+///
+Function* LowerParaBr::getFuncThreadJoin(Module &M) {
+  // void __llvm_thread_join(int);
+  return M.getOrInsertFunction("__llvm_thread_join", Type::VoidTy, Type::IntTy, 
+                               0);
 }





More information about the llvm-commits mailing list