[llvm-commits] [parallel] CVS: llvm/lib/Transforms/Parallel/LowerParaBr.cpp
Misha Brukman
brukman at cs.uiuc.edu
Mon May 17 19:25:01 PDT 2004
Changes in directory llvm/lib/Transforms/Parallel:
LowerParaBr.cpp updated: 1.1.2.2 -> 1.1.2.3
---
Log message:
Create a unified join block which contains all the join calls for the region
---
Diffs of the changes: (+75 -8)
Index: llvm/lib/Transforms/Parallel/LowerParaBr.cpp
diff -u llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.2 llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.3
--- llvm/lib/Transforms/Parallel/LowerParaBr.cpp:1.1.2.2 Fri May 7 16:37:09 2004
+++ llvm/lib/Transforms/Parallel/LowerParaBr.cpp Mon May 17 19:24:12 2004
@@ -16,10 +16,11 @@
//
//===----------------------------------------------------------------------===//
+#include "llvm/BasicBlock.h"
#include "llvm/DerivedTypes.h"
-#include "llvm/Module.h"
#include "llvm/iOther.h"
#include "llvm/iTerminators.h"
+#include "llvm/Module.h"
#include "llvm/Pass.h"
#include "llvm/Type.h"
#include "llvm/Analysis/ParallelInfo.h"
@@ -44,6 +45,7 @@
private:
void straightenSequence(ParallelSeq *PS);
Function* getJoinIntrinsic(Module *M);
+ Function* getFuncThreadJoin(Module &M);
};
RegisterOpt<LowerParaBr> X("lowerpbr", "Lower pbr to sequential code");
@@ -67,10 +69,21 @@
return Changed;
}
-static bool contains(std::vector<BasicBlock*> haystack, BasicBlock *needle) {
+static inline bool  // true if needle occurs in haystack (linear std::find)
+contains(const std::vector<BasicBlock*> &haystack, const BasicBlock *needle) {
return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
}
+static inline bool  // true if some call in haystack has needle's operand 1
+containsJoin(const std::vector<CallInst*> &haystack, const CallInst *needle) {
+ for (std::vector<CallInst*>::const_iterator h = haystack.begin(),
+ he = haystack.end(); h != he; ++h)
+ if ((*h)->getOperand(1) == needle->getOperand(1)) // op 0 is the callee
+ return true;
+
+ return false;
+}
+
/// straightenSequence - Recursively process parallel sequences
///
void LowerParaBr::straightenSequence(ParallelSeq *PS) {
@@ -80,15 +93,16 @@
{
PrevPR = PR;
PR = *i;
+#if 0
for (ParallelRegion::seqiterator ci = PR->seqbegin(), ce = PR->seqend();
ci != ce; ++ci)
straightenSequence(*ci);
-
- // Stitch previous region to the current one by all branches to the join
- // block instead branch to the first block of the second region
+#endif
if (PrevPR) {
+ // Stitch previous region to the current one by all branches to the join
+ // block instead branch to the first block of the second region
std::vector<BasicBlock*> &PrevJoinBlocks = PrevPR->getJoinBlocks(),
- &PRBlocks = PR->getBlocks();
+ &PRJoinBlocks = PR->getJoinBlocks(), &PRBlocks = PR->getBlocks();
BasicBlock *PRFirstBB = PRBlocks[0];
for (std::vector<BasicBlock*>::iterator j = PrevJoinBlocks.begin(),
je = PrevJoinBlocks.end(); j != je; ++j) {
@@ -96,9 +110,55 @@
for (std::vector<User*>::iterator u = JUsers.begin(), ue = JUsers.end();
u != ue; ++u)
if (Instruction *I = dyn_cast<Instruction>(*u))
- if (contains(PR->getBlocks(), I->getParent()))
+ if (contains(PrevPR->getBlocks(), I->getParent()))
I->replaceUsesOfWith(*j, PRFirstBB);
}
+
+ // Make a new join block that will house all the (coallesced) joins from
+ // both regions
+ BasicBlock *SumJoins = new BasicBlock("allJoins", PRFirstBB->getParent());
+
+ // Find the unique set of calls to the llvm.join intrinsic
+ std::vector<CallInst*> AllJoins;
+ for (std::vector<BasicBlock*>::iterator i = PrevJoinBlocks.begin(),
+ e = PrevJoinBlocks.end(); i != e; ++i) {
+ BasicBlock *BB = *i;
+ for (BasicBlock::iterator j = BB->begin(), je = BB->end(); j!=je; ++j)
+ if (CallInst *CI = dyn_cast<CallInst>(j))
+ if (CI->getCalledFunction()->getName() == "__llvm_thread_join")
+ if (!containsJoin(AllJoins, CI))
+ AllJoins.push_back(CI);
+ }
+ for (std::vector<BasicBlock*>::iterator i = PRJoinBlocks.begin(),
+ e = PRJoinBlocks.end(); i != e; ++i) {
+ BasicBlock *BB = *i;
+ for (BasicBlock::iterator j = BB->begin(), je = BB->end(); j!=je; ++j)
+ if (CallInst *CI = dyn_cast<CallInst>(j))
+ if (CI->getCalledFunction()->getName() == "__llvm_thread_join")
+ if (!containsJoin(AllJoins, CI))
+ AllJoins.push_back(CI);
+ }
+
+ // Add all unique joins to the coallesced joins block
+ for (std::vector<CallInst*>::iterator c = AllJoins.begin(),
+ ce = AllJoins.end(); c != ce; ++c)
+ SumJoins->getInstList().push_back((*c)->clone());
+
+ // Terminate the Joins block with a branch to the post-join code
+ Instruction *TI = PRJoinBlocks[0]->getTerminator();
+ if (TI)
+ SumJoins->getInstList().push_back(TI->clone());
+
+ // Branch to the all-joins block from the main code
+ for (std::vector<BasicBlock*>::iterator j = PRJoinBlocks.begin(),
+ je = PRJoinBlocks.end(); j != je; ++j) {
+ std::vector<User*> JUsers((*j)->use_begin(), (*j)->use_end());
+ for (std::vector<User*>::iterator u = JUsers.begin(), ue = JUsers.end();
+ u != ue; ++u)
+ if (Instruction *I = dyn_cast<Instruction>(*u))
+ if (contains(PR->getBlocks(), I->getParent()))
+ I->replaceUsesOfWith(*j, SumJoins);
+ }
}
}
@@ -115,7 +175,6 @@
// Remove pbr
ParaBrInst *pbr = dyn_cast<ParaBrInst>(PS->getHeader()->getTerminator());
- //assert(pbr && "Terminator of ParaSeq header is not a pbr");
if (pbr) {
new BranchInst(pbr->getSuccessor(0), pbr);
pbr->getParent()->getInstList().erase(pbr);
@@ -125,4 +184,12 @@
Function* LowerParaBr::getJoinIntrinsic(Module *M) {
return M->getOrInsertFunction("llvm.join", Type::VoidTy,
PointerType::get(Type::SByteTy), 0);
+}
+
+/// getFuncThreadJoin - Return M's declaration of __llvm_thread_join,
+/// inserting the prototype void(int) via getOrInsertFunction if absent.
+Function* LowerParaBr::getFuncThreadJoin(Module &M) {
+ // void __llvm_thread_join(int);
+ return M.getOrInsertFunction("__llvm_thread_join", Type::VoidTy, Type::IntTy,
+ 0); // null terminates the parameter-type list
}
More information about the llvm-commits mailing list