[llvm] [OpenMPIRBuilder] Add support for target workshare loops (PR #73360)
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Tue Nov 28 11:14:38 PST 2023
================
@@ -2681,11 +2681,255 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}
+// Returns an LLVM function to call for executing an OpenMP static worksharing
+// for loop depending on `Ty` and `LoopType`. Only i32 and i64 are supported
+// by the runtime. Always interpret integers as unsigned, similarly to
+// CanonicalLoopInfo.
+static FunctionCallee
+getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
+ OpenMPIRBuilder::WorksharingLoopType LoopType) {
+ unsigned Bitwidth = Ty->getIntegerBitWidth();
+ Module &M = OMPBuilder->M;
+ switch (LoopType) {
+ case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
+ if (Bitwidth == 32)
+ return OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
+ if (Bitwidth == 64)
+ return OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
+ break;
+ case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
+ if (Bitwidth == 32)
+ return OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
+ if (Bitwidth == 64)
+ return OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
+ break;
+ case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
+ if (Bitwidth == 32)
+ return OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
+ if (Bitwidth == 64)
+ return OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
+ break;
+ }
+ if (Bitwidth != 32 && Bitwidth != 64)
+ llvm_unreachable("unknown OpenMP loop iterator bitwidth");
+ return FunctionCallee();
+}
+
+// Inserts a call to the proper OpenMP Device RTL function that handles
+// loop worksharing.
+static void createTargetLoopWorkshareCall(
+ OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
+ BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
+ Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
+ Type *TripCountTy = TripCount->getType();
+ Module &M = OMPBuilder->M;
+ IRBuilder<> &Builder = OMPBuilder->Builder;
+ FunctionCallee RTLFn =
+ getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
+ SmallVector<Value *, 8> RealArgs;
+ RealArgs.push_back(Ident);
+ /*loop body func*/
+ RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
+ /*loop body args*/
+ RealArgs.push_back(LoopBodyArg);
+ /*num of iters*/
+ RealArgs.push_back(TripCount);
+ if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
+ /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+ ? Builder.getInt32(0)
+ : Builder.getInt64(0));
+ Builder.CreateCall(RTLFn, RealArgs);
+ return;
+ }
+ FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
+ Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+ Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
+
+ /*num of threads*/ RealArgs.push_back(
+ Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
+ if (LoopType ==
+ OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
+ /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+ ? Builder.getInt32(0)
+ : Builder.getInt64(0));
+ }
+ /*thread chunk */ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+ ? Builder.getInt32(1)
+ : Builder.getInt64(1));
+
+ Builder.CreateCall(RTLFn, RealArgs);
+}
+
+static void
+workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
+ CanonicalLoopInfo *CLI, Value *Ident,
+ Function &OutlinedFn, Type *ParallelTaskPtr,
+ const SmallVector<Instruction *, 4> &ToBeDeleted,
+ OpenMPIRBuilder::WorksharingLoopType LoopType) {
+ IRBuilder<> &Builder = OMPIRBuilder->Builder;
+ BasicBlock *Preheader = CLI->getPreheader();
+ Value *TripCount = CLI->getTripCount();
+
+ // After loop body outlining, the loop body contains only the setup of the
+ // loop body argument structure and the call to the outlined loop body
+ // function. First, we need to move the setup of the loop body args into the
+ // loop preheader.
+ Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
+ CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
+
+ // The next step is to remove the whole loop. We do not need it anymore.
+ // That's why we make an unconditional branch from the loop preheader to the
+ // loop exit block.
+ Builder.restoreIP({Preheader, Preheader->end()});
+ Preheader->getTerminator()->eraseFromParent();
+ Builder.CreateBr(CLI->getExit());
+
+ // Delete dead loop blocks
+ OpenMPIRBuilder::OutlineInfo CleanUpInfo;
+ SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
+ SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
+ CleanUpInfo.EntryBB = CLI->getHeader();
+ CleanUpInfo.ExitBB = CLI->getExit();
+ CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
+ DeleteDeadBlocks(BlocksToBeRemoved);
+
+ // Find the instruction that corresponds to the loop body argument structure
+ // and remove the call to the loop body function.
+ Value *LoopBodyArg = nullptr;
+ for (auto instIt = Preheader->begin(); instIt != Preheader->end(); ++instIt) {
+ if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
+ if (CallInstruction->getCalledFunction() == &OutlinedFn) {
+ // Check in case no argument structure has been passed.
+ if (CallInstruction->arg_size() > 1)
+ LoopBodyArg = CallInstruction->getArgOperand(1);
+ else
+ LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
+ CallInstruction->eraseFromParent();
+ break;
+ }
+ }
+ }
+
+ createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
+ LoopBodyArg, ParallelTaskPtr, TripCount,
+ OutlinedFn);
+
+ for (auto &ToBeDeletedItem : ToBeDeleted)
+ ToBeDeletedItem->eraseFromParent();
+ CLI->invalidate();
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
+ DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+ OpenMPIRBuilder::WorksharingLoopType LoopType) {
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
+ Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+ OutlineInfo OI;
+ OI.OuterAllocaBB = CLI->getPreheader();
+ Function *OuterFn = CLI->getPreheader()->getParent();
+
+ // Instructions which need to be deleted at the end of code generation
+ SmallVector<Instruction *, 4> ToBeDeleted;
+
+ OI.OuterAllocaBB = AllocaIP.getBlock();
+
+ // Mark the loop body as a region which needs to be extracted.
+ OI.EntryBB = CLI->getBody();
+ OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
+ "omp.prelatch", true);
+
+ // Prepare loop body for extraction
+ Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
+
+ // Insert a new loop counter variable which will be used only in the loop
+ // body.
+ AllocaInst *newLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
+ Instruction *newLoopCntLoad =
+ Builder.CreateLoad(CLI->getIndVarType(), newLoopCnt);
+ // The new loop counter instructions are redundant in the loop preheader once
+ // code generation for the workshare loop is finished. That's why we mark them
+ // as ready for deletion.
+ ToBeDeleted.push_back(newLoopCntLoad);
+ ToBeDeleted.push_back(newLoopCnt);
+
+ // Analyse the loop body region. Find all input variables which are used
+ // inside the loop body region.
+ SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
+ SmallVector<BasicBlock *, 32> Blocks;
+ OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+ SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
+ ParallelRegionBlockSet.end());
+
+ CodeExtractorAnalysisCache CEAC(*OuterFn);
+ CodeExtractor Extractor(Blocks,
+ /* DominatorTree */ nullptr,
+ /* AggregateArgs */ true,
+ /* BlockFrequencyInfo */ nullptr,
+ /* BranchProbabilityInfo */ nullptr,
+ /* AssumptionCache */ nullptr,
+ /* AllowVarArgs */ true,
+ /* AllowAlloca */ true,
+ /* AllocationBlock */ CLI->getPreheader(),
+ /* Suffix */ ".omp_wsloop",
+ /* AggrArgsIn0AddrSpace */ true);
+
+ BasicBlock *CommonExit = nullptr;
+ SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+
+ // Find allocas outside the loop body region which are used inside the loop
+ // body.
+ Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+
+ // We need to model the loop body region as the function f(cnt, loop_arg).
+ // That's why we replace the loop induction variable with the new counter,
+ // which will be one of the loop body function's arguments.
+ std::vector<User *> Users(CLI->getIndVar()->user_begin(),
+ CLI->getIndVar()->user_end());
+ for (User *use : Users) {
+ if (Instruction *inst = dyn_cast<Instruction>(use)) {
+ if (ParallelRegionBlockSet.count(inst->getParent())) {
+ inst->replaceUsesOfWith(CLI->getIndVar(), newLoopCntLoad);
----------------
jdoerfert wrote:
Style, `Use`, `Inst`, ...
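For reference, a minimal sketch of the flagged loop with LLVM-style UpperCamelCase locals (hypothetical; it assumes the earlier `newLoopCnt`/`newLoopCntLoad` pair is also renamed to `NewLoopCnt`/`NewLoopCntLoad` so the use below resolves):

```cpp
// Sketch only: capitalized locals per the LLVM Coding Standards.
// `NewLoopCntLoad` assumes the earlier `newLoopCntLoad` is renamed to match.
for (User *Use : Users) {
  if (Instruction *Inst = dyn_cast<Instruction>(Use)) {
    if (ParallelRegionBlockSet.count(Inst->getParent()))
      Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
  }
}
```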
https://github.com/llvm/llvm-project/pull/73360
More information about the llvm-commits mailing list