[llvm-commits] [parallel] CVS: llvm/lib/Transforms/Parallel/LowerParaBr.cpp
Misha Brukman
brukman at cs.uiuc.edu
Fri May 7 16:37:01 PDT 2004
Changes in directory llvm/lib/Transforms/Parallel:
LowerParaBr.cpp updated: 1.1.2.1 -> 1.1.2.2
---
Log message:
llvm.join calls no longer necessarily use the pbr directly, and a single pbr may
have more than one join, so all joins in a given parallel region must be removed.
---
Diffs of the changes: (+38 -15)
Index: llvm/lib/Transforms/Parallel/LowerParaBr.cpp
diff -u llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.1 llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.2
--- llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.1 Sat Apr 17 19:45:51 2004
+++ llvm/lib/Transforms/Parallel/LowerParaBr.cpp Fri May 7 16:37:09 2004
@@ -43,6 +43,7 @@
private:
void straightenSequence(ParallelSeq *PS);
+ Function* getJoinIntrinsic(Module *M);
};
RegisterOpt<LowerParaBr> X("lowerpbr", "Lower pbr to sequential code");
@@ -66,6 +67,10 @@
return Changed;
}
+/// contains - Return true if needle occurs anywhere in haystack (linear scan).
+/// The vector is taken by const reference so that callers' block lists are not
+/// copied on every lookup; this is called once per use inside nested loops.
+static bool contains(const std::vector<BasicBlock*> &haystack,
+                     BasicBlock *needle) {
+  return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
+}
+
/// straightenSequence - Recursively process parallel sequences
///
void LowerParaBr::straightenSequence(ParallelSeq *PS) {
@@ -79,27 +84,45 @@
ci != ce; ++ci)
straightenSequence(*ci);
- // Stitch previous region to the current one by rewriting the last branch
+ // Stitch the previous region to the current one by redirecting uses of the
+ // previous region's join blocks to the first block of the current region.
if (PrevPR) {
- std::vector<BasicBlock*> &PrevBlocks = PrevPR->getBlocks(),
+ std::vector<BasicBlock*> &PrevJoinBlocks = PrevPR->getJoinBlocks(),
&PRBlocks = PR->getBlocks();
- BasicBlock *PrevLastBB = PrevBlocks.back(), *PRFirstBB = PRBlocks[0];
- TerminatorInst *TI = PrevLastBB->getTerminator();
- new BranchInst(PRFirstBB, TI);
- PrevLastBB->getInstList().erase(TI);
+ BasicBlock *PRFirstBB = PRBlocks[0];
+ for (std::vector<BasicBlock*>::iterator j = PrevJoinBlocks.begin(),
+ je = PrevJoinBlocks.end(); j != je; ++j) {
+ std::vector<User*> JUsers((*j)->use_begin(), (*j)->use_end());
+ for (std::vector<User*>::iterator u = JUsers.begin(), ue = JUsers.end();
+ u != ue; ++u)
+ if (Instruction *I = dyn_cast<Instruction>(*u))
+ if (contains(PR->getBlocks(), I->getParent()))
+ I->replaceUsesOfWith(*j, PRFirstBB);
+ }
}
}
- // Remove the call to llvm.join()
- ParaBrInst *pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
- assert(pbr && "pbr not a terminator of ParallelSeq header!");
- std::vector<User*> Users(pbr->use_begin(), pbr->use_end());
+ Module *M = PS->getHeader()->getParent()->getParent();
+ Function *JoinIntr = getJoinIntrinsic(M);
+ std::vector<User*> Users(JoinIntr->use_begin(), JoinIntr->use_end());
+ // Remove the calls to llvm.join()
for (std::vector<User*>::iterator use = Users.begin(), ue = Users.end();
use != ue; ++use)
- if (Instruction *Inst = dyn_cast<Instruction>(*use))
- Inst->getParent()->getInstList().erase(Inst);
+ if (CallInst *CI = dyn_cast<CallInst>(*use))
+ if (contains(PR->getJoinBlocks(), CI->getParent()) ||
+ contains(PrevPR->getJoinBlocks(), CI->getParent()))
+ CI->getParent()->getInstList().erase(CI);
+
+ // Remove pbr
+ ParaBrInst *pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
+ //assert(pbr && "Terminator of ParaSeq header is not a pbr");
+ if (pbr) {
+ new BranchInst(pbr->getSuccessor(0), pbr);
+ pbr->getParent()->getInstList().erase(pbr);
+ }
+}
- // Rewrite pbr as a regular branch
- new BranchInst(pbr->getSuccessor(0), pbr);
- pbr->getParent()->getInstList().erase(pbr);
+/// getJoinIntrinsic - Return the declaration of llvm.join in module M: a
+/// function returning void and taking a single sbyte* argument. It is obtained
+/// via Module::getOrInsertFunction, so a declaration is created if M lacks one.
+Function* LowerParaBr::getJoinIntrinsic(Module *M) {
+  // The trailing argument terminates getOrInsertFunction's variadic list of
+  // parameter types. Pass it as a typed null pointer rather than the int
+  // literal 0: reading an int vararg back as a pointer is undefined and can
+  // misbehave where int and pointers differ in size (e.g. LP64 targets).
+  return M->getOrInsertFunction("llvm.join", Type::VoidTy,
+                                PointerType::get(Type::SByteTy), (Type*)0);
}
More information about the llvm-commits
mailing list