[llvm] [OpenMPIRBuilder] Add support for target workshare loops (PR #73360)
Dominik Adamski via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 30 02:50:32 PST 2023
================
@@ -2681,11 +2681,255 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
}
+// Returns the OpenMP device runtime function that executes a static
+// worksharing loop of kind `LoopType` whose iteration variable has type `Ty`.
+// Only i32 and i64 are supported by the runtime (the `_4u` / `_8u` entry
+// points). Integers are always interpreted as unsigned, similarly to
+// CanonicalLoopInfo.
+static FunctionCallee
+getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  unsigned Bitwidth = Ty->getIntegerBitWidth();
+  // Validate the width once, up front, so each case below only has to pick
+  // between the 32-bit and the 64-bit variant of the runtime function.
+  if (Bitwidth != 32 && Bitwidth != 64)
+    llvm_unreachable("unknown OpenMP loop iterator bitwidth");
+  Module &M = OMPBuilder->M;
+  // Declares (or reuses) the runtime function matching the iterator width.
+  auto SelectByWidth = [&](omp::RuntimeFunction Fn32,
+                           omp::RuntimeFunction Fn64) {
+    return OMPBuilder->getOrCreateRuntimeFunction(M,
+                                                  Bitwidth == 32 ? Fn32 : Fn64);
+  };
+  switch (LoopType) {
+  case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
+    return SelectByWidth(
+        omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u,
+        omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
+    return SelectByWidth(
+        omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u,
+        omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
+    return SelectByWidth(
+        omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u,
+        omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
+  }
+  // Fully covered switch: reaching this point means a new loop kind was added
+  // without a matching runtime entry point. Fail loudly instead of returning a
+  // null FunctionCallee that would only crash later inside CreateCall.
+  llvm_unreachable("unknown OpenMP worksharing loop type");
+}
+
+// Inserts a call to proper OpenMP Device RTL function which handles
+// loop worksharing.
+static void createTargetLoopWorkshareCall(
+ OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
+ BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
+ Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
+ Type *TripCountTy = TripCount->getType();
+ Module &M = OMPBuilder->M;
+ IRBuilder<> &Builder = OMPBuilder->Builder;
+ FunctionCallee RTLFn =
+ getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
+ SmallVector<Value *, 8> RealArgs;
+ RealArgs.push_back(Ident);
+ /*loop body func*/
+ RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
+ /*loop body args*/
+ RealArgs.push_back(LoopBodyArg);
+ /*num of iters*/
+ RealArgs.push_back(TripCount);
+ if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
+ /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+ ? Builder.getInt32(0)
+ : Builder.getInt64(0));
+ Builder.CreateCall(RTLFn, RealArgs);
+ return;
+ }
+ FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
+ M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
+ Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+ Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
+
+ /*num of threads*/ RealArgs.push_back(
+ Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
+ if (LoopType ==
+ OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
+ /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+ ? Builder.getInt32(0)
+ : Builder.getInt64(0));
+ }
+ /*thread chunk */ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+ ? Builder.getInt32(1)
+ : Builder.getInt64(1));
----------------
DominikAdamski wrote:
Done
https://github.com/llvm/llvm-project/pull/73360
More information about the llvm-commits
mailing list