[llvm-commits] [parallel] CVS: llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp
Misha Brukman
brukman at cs.uiuc.edu
Fri May 7 16:38:01 PDT 2004
Changes in directory llvm/lib/Transforms/Parallel:
ParallelCallsToThreads.cpp updated: 1.1.2.2 -> 1.1.2.3
---
Log message:
Convert calls to threads separately for _each_ parallel region, because the values
we join() on depend on the region the threads were spawned from.
---
Diffs of the changes: (+47 -25)
Index: llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp
diff -u llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp:1.1.2.2 llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp:1.1.2.3
--- llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp:1.1.2.2 Sat Apr 17 20:10:58 2004
+++ llvm/lib/Transforms/Parallel/ParallelCallsToThreads.cpp Fri May 7 16:37:56 2004
@@ -11,10 +11,10 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/Constant.h"
#include "llvm/DerivedTypes.h"
#include "llvm/Module.h"
-#include "llvm/iOther.h"
-#include "llvm/iTerminators.h"
+#include "llvm/Instructions.h"
#include "llvm/Pass.h"
#include "llvm/Type.h"
#include "llvm/Analysis/ParallelInfo.h"
@@ -30,6 +30,7 @@
///
struct PCallToThreads : public FunctionPass {
Type *startTy;
+ typedef std::vector<BasicBlock*> BBVec;
public:
PCallToThreads() {
@@ -57,6 +58,19 @@
} // End anonymous namespace
+/// findPbrImpl - Worker for findPbr.  Visited records the PHI nodes already
+/// explored, so that cyclic PHI webs (e.g. a loop PHI listing itself as an
+/// incoming value) terminate instead of recursing forever.
+static bool findPbrImpl(Value *V, ParaBrInst *pbr,
+                        std::vector<PHINode*> &Visited) {
+  if (V == pbr)
+    return true;
+  PHINode *phi = dyn_cast<PHINode>(V);
+  if (!phi)
+    return false;
+  // A PHI seen before is either fully explored (and did not reach pbr) or is
+  // still on the DFS stack (its remaining inputs will be checked on unwind).
+  for (unsigned v = 0, ve = Visited.size(); v != ve; ++v)
+    if (Visited[v] == phi)
+      return false;
+  Visited.push_back(phi);
+  for (unsigned i = 0, e = phi->getNumIncomingValues(); i != e; ++i)
+    if (findPbrImpl(phi->getIncomingValue(i), pbr, Visited))
+      return true;
+  return false;
+}
+
+/// findPbr - Return true if V is pbr itself, or is a PHI node that (directly
+/// or through other PHI nodes) has pbr as one of its incoming values.
+static bool findPbr(Value *V, ParaBrInst *pbr) {
+  std::vector<PHINode*> Visited;
+  return findPbrImpl(V, pbr, Visited);
+}
+
/// runOnFunction -
///
bool PCallToThreads::runOnFunction(Function &F) {
@@ -66,11 +80,12 @@
// Convert parallel calls to pthread_create() invocations
for (ParallelInfo::iterator i = PI.begin(), e = PI.end(); i != e; ++i) {
ParallelSeq *PS = *i;
- std::vector<Value*> JoinValues;
- for (ParallelSeq::riterator r = PS->rbegin(),
- re = PS->rend(); r != re; ++r) {
+ for (ParallelSeq::riterator r = PS->rbegin(), re = PS->rend();
+ r != re; ++r) {
+ std::vector<Value*> JoinValues;
ParallelRegion *PR = *r;
std::vector<BasicBlock*> RegionBlocks(PR->begin(), PR->end());
+ if (RegionBlocks.empty()) continue;
// Ensure that there is only one block in this region
assert(RegionBlocks.size() == 1 && "Parallel region has > 1 BB");
@@ -84,15 +99,17 @@
// Replace call with __llvm_thread_create
Function *ThCreate = getFuncThreadStart(*F.getParent());
- assert(OldCall->getNumOperands() == 2 &&
- "Can only threadify calls with one argument!");
+ assert(OldCall->getNumOperands() <= 2 &&
+ "Can only threadify calls with at most one argument!");
TerminatorInst *TI = BB->getTerminator();
CastInst *funcPtr = new CastInst(OldCall->getOperand(0), startTy,
"cast_ptr", TI);
- CastInst *funcVal = new CastInst(OldCall->getOperand(1),
- PointerType::get(Type::SByteTy),
- "cast_val", TI);
+ Value *funcVal = Constant::getNullValue(PointerType::get(Type::SByteTy));
+ if (OldCall->getNumOperands() == 2)
+ funcVal = new CastInst(OldCall->getOperand(1),
+ PointerType::get(Type::SByteTy),
+ "cast_val", TI);
std::vector<Value*> Args;
Args.push_back(funcPtr);
Args.push_back(funcVal);
@@ -101,24 +118,29 @@
OldCall->getParent()->getInstList().erase(OldCall);
+ for (std::vector<BasicBlock*>::iterator BBr = RegionBlocks.begin(),
+ BBre = RegionBlocks.end(); BBr != BBre; ++BBr)
+ PR->removeBasicBlock(*BBr);
+
+ // Convert llvm.join() intrinsic to __llvm_thread_join() calls
+ // Get join function call position/bb
+ ParaBrInst *Pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
+ Function *ThJoin = getFuncThreadJoin(*F.getParent());
+ assert(Pbr && "Terminator of parallel sequence header is not a Pbr!");
+ BBVec &JoinBlocks = PR->getJoinBlocks();
+ for (BBVec::iterator j = JoinBlocks.begin(), je = JoinBlocks.end();
+ j != je; ++j)
+ for (BasicBlock::iterator i = (*j)->begin(), e=(*j)->end(); i!=e; ++i)
+ if (CallInst *CI = dyn_cast<CallInst>(i))
+ if (CI->getCalledFunction()->getName() == "llvm.join" &&
+ findPbr(CI->getOperand(1), Pbr))
+ for (std::vector<Value*>::iterator val = JoinValues.begin(),
+ ve = JoinValues.end(); val != ve; ++val)
+ CallInst *Join = new CallInst(ThJoin, *val, "", CI);
+
Changed = true;
}
- // Convert llvm.join() intrinsic to __llvm_thread_join() calls
- // Get join function call position/bb
- ParaBrInst *Pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
- assert(Pbr && "Terminator of parallel sequence header is not a Pbr!");
-
- std::vector<User*> Users(Pbr->use_begin(), Pbr->use_end());
- assert(Users.size() == 1 && "Must have unique user of Pbr");
- CallInst *JoinCall = cast<CallInst>(Users[0]);
- assert(JoinCall && "Pbr result used in something other than call!");
-
- Function *ThJoin = getFuncThreadJoin(*F.getParent());
- TerminatorInst *JoinBBTerm = JoinCall->getParent()->getTerminator();
- for (std::vector<Value*>::iterator val = JoinValues.begin(),
- ve = JoinValues.end(); val != ve; ++val)
- CallInst *Join = new CallInst(ThJoin, *val, "", JoinBBTerm);
}
return Changed;
More information about the llvm-commits mailing list