[llvm-commits] [parallel] CVS: llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp

Misha Brukman brukman at cs.uiuc.edu
Fri May 7 16:38:01 PDT 2004


Changes in directory llvm/lib/Transforms/Parallel:

ParallelCallsToThreads.cpp updated: 1.1.2.2 -> 1.1.2.3

---
Log message:

Convert calls to threads for _each_ parallel region, as the values we join() on
depend on the region we spawned off.


---
Diffs of the changes:  (+47 -25)

Index: llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp
diff -u llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp:1.1.2.2 llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp:1.1.2.3
--- llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp:1.1.2.2	Sat Apr 17 20:10:58 2004
+++ llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp	Fri May  7 16:37:56 2004
@@ -11,10 +11,10 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "llvm/Constant.h"
 #include "llvm/DerivedTypes.h"
 #include "llvm/Module.h"
-#include "llvm/iOther.h"
-#include "llvm/iTerminators.h"
+#include "llvm/Instructions.h"
 #include "llvm/Pass.h"
 #include "llvm/Type.h"
 #include "llvm/Analysis/ParallelInfo.h"
@@ -30,6 +30,7 @@
   ///
   struct PCallToThreads : public FunctionPass {
     Type *startTy;
+    typedef std::vector<BasicBlock*> BBVec;
 
   public:
     PCallToThreads() {
@@ -57,6 +58,19 @@
 
 } // End anonymous namespace
 
+static bool findPbr(Value *V, ParaBrInst *pbr) { // True if V is pbr, or a PHI whose incoming values (transitively) include pbr.
+  if (V == pbr)
+    return true;
+  else if (PHINode *phi = dyn_cast<PHINode>(V)) { // Look through PHIs into each incoming value.
+    bool Found = false; // NOTE(review): never read; looks removable.
+    for (unsigned i = 0, e = phi->getNumIncomingValues(); i != e; ++i)
+      if (findPbr(phi->getIncomingValue(i), pbr)) // NOTE(review): no visited set -- if PHIs feed each other in a
+        return true;                               // cycle (e.g. loop-header PHIs) this could recurse without bound.
+    return false;
+  } else
+    return false; // Neither pbr nor a PHI: other instructions are not looked through.
+}
+
 /// runOnFunction - 
 ///
 bool PCallToThreads::runOnFunction(Function &F) {
@@ -66,11 +80,12 @@
   // Convert parallel calls to pthread_create() invocations
   for (ParallelInfo::iterator i = PI.begin(), e = PI.end(); i != e; ++i) {
     ParallelSeq *PS = *i;
-    std::vector<Value*> JoinValues;
-    for (ParallelSeq::riterator r = PS->rbegin(),
-           re = PS->rend(); r != re; ++r) {
+    for (ParallelSeq::riterator r = PS->rbegin(), re = PS->rend(); 
+         r != re; ++r) {
+      std::vector<Value*> JoinValues;
       ParallelRegion *PR = *r;
       std::vector<BasicBlock*> RegionBlocks(PR->begin(), PR->end());
+      if (RegionBlocks.empty()) continue;
 
       // Ensure that there is only one block in this region
       assert(RegionBlocks.size() == 1 && "Parallel region has > 1 BB");
@@ -84,15 +99,17 @@
       
       // Replace call with __llvm_thread_create
       Function *ThCreate = getFuncThreadStart(*F.getParent());
-      assert(OldCall->getNumOperands() == 2 && 
-             "Can only threadify calls with one argument!");
+      assert(OldCall->getNumOperands() <= 2 && 
+             "Can only threadify calls with at most one argument!");
 
       TerminatorInst *TI = BB->getTerminator();
       CastInst *funcPtr = new CastInst(OldCall->getOperand(0), startTy,
                                        "cast_ptr", TI);
-      CastInst *funcVal = new CastInst(OldCall->getOperand(1),
-                                       PointerType::get(Type::SByteTy),
-                                       "cast_val", TI);
+      Value *funcVal = Constant::getNullValue(PointerType::get(Type::SByteTy));
+      if (OldCall->getNumOperands() == 2)
+        funcVal = new CastInst(OldCall->getOperand(1),
+                               PointerType::get(Type::SByteTy),
+                               "cast_val", TI);
       std::vector<Value*> Args;
       Args.push_back(funcPtr);
       Args.push_back(funcVal);
@@ -101,24 +118,29 @@
 
       OldCall->getParent()->getInstList().erase(OldCall);
 
+      for (std::vector<BasicBlock*>::iterator BBr = RegionBlocks.begin(),
+             BBre = RegionBlocks.end(); BBr != BBre; ++BBr)
+        PR->removeBasicBlock(*BBr);
+
+      // Convert llvm.join() intrinsic to __llvm_thread_join() calls
+      // Get join function call position/bb
+      ParaBrInst *Pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
+      Function *ThJoin = getFuncThreadJoin(*F.getParent());
+      assert(Pbr && "Terminator of parallel sequence header is not a Pbr!");
+      BBVec &JoinBlocks = PR->getJoinBlocks();
+      for (BBVec::iterator j = JoinBlocks.begin(), je = JoinBlocks.end();
+           j != je; ++j)
+        for (BasicBlock::iterator i = (*j)->begin(), e=(*j)->end(); i!=e; ++i)
+          if (CallInst *CI = dyn_cast<CallInst>(i))
+            if (CI->getCalledFunction()->getName() == "llvm.join" &&
+                findPbr(CI->getOperand(1), Pbr))
+              for (std::vector<Value*>::iterator val = JoinValues.begin(), 
+                     ve = JoinValues.end(); val != ve; ++val)
+                CallInst *Join = new CallInst(ThJoin, *val, "", CI);
+
       Changed = true;
     }
 
-    // Convert llvm.join() intrinsic to __llvm_thread_join() calls
-    // Get join function call position/bb
-    ParaBrInst *Pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
-    assert(Pbr && "Terminator of parallel sequence header is not a Pbr!");
-    
-    std::vector<User*> Users(Pbr->use_begin(), Pbr->use_end());
-    assert(Users.size() == 1 && "Must have unique user of Pbr");
-    CallInst *JoinCall = cast<CallInst>(Users[0]);
-    assert(JoinCall && "Pbr result used in something other than call!");
-    
-    Function *ThJoin = getFuncThreadJoin(*F.getParent());
-    TerminatorInst *JoinBBTerm = JoinCall->getParent()->getTerminator();
-    for (std::vector<Value*>::iterator val = JoinValues.begin(), 
-           ve = JoinValues.end(); val != ve; ++val)
-      CallInst *Join = new CallInst(ThJoin, *val, "", JoinBBTerm);
   }
 
   return Changed;





More information about the llvm-commits mailing list