[llvm] [OpenMP][OMPIRBuilder] Add support to omp target parallel (PR #67000)
Johannes Doerfert via llvm-commits
llvm-commits at lists.llvm.org
Mon Oct 23 11:04:43 PDT 2023
================
@@ -1126,6 +1133,185 @@ void OpenMPIRBuilder::emitCancelationCheckImpl(Value *CancelFlag,
Builder.SetInsertPoint(NonCancellationBlock, NonCancellationBlock->begin());
}
+// Callback used to create OpenMP runtime calls to support the omp parallel
+// construct on the device.
+//
+// After outlining, OuterFn contains a direct call to OutlinedFn. This
+// callback replaces that call with a call to the OpenMP DeviceRTL runtime
+// function __kmpc_parallel_51:
+//   - the captured values (the call operands after the leading tid and
+//     bound-tid arguments) are stored into a pointer array allocated at the
+//     first insertion point of OuterFn's entry block;
+//   - that array, its element count, the if/num_threads values and the
+//     outlined function are then passed to the runtime.
+// Afterwards the private TID slot is initialized, and the original call and
+// every instruction in ToBeDeleted are erased.
+//
+// \param OMPIRBuilder  Builder whose IRBuilder is used (its insertion point
+//                      is moved by this function).
+// \param OutlinedFn    The outlined region. The last entry of its user list
+//                      is expected to be the direct call being replaced.
+// \param OuterFn       Function containing that call; receives the alloca.
+// \param Ident         Value passed as the runtime's identifier argument.
+// \param IfCondition   Optional 'if' value; null is passed as i32 1.
+// \param NumThreads    Optional num_threads value; null is passed as i32 -1.
+// \param PrivTID       Insertion point for initializing PrivTIDAddr.
+// \param PrivTIDAddr   Slot receiving the i32 loaded through OutlinedFn's
+//                      first argument.
+// \param ThreadID      Value passed as the runtime's global thread number.
+// \param ToBeDeleted   Instructions erased once the runtime call is built.
+static void
+targetParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
+                       Function *OuterFn, Value *Ident, Value *IfCondition,
+                       Value *NumThreads, Instruction *PrivTID,
+                       AllocaInst *PrivTIDAddr, Value *ThreadID,
+                       const SmallVector<Instruction *, 4> &ToBeDeleted) {
+  Module &M = OMPIRBuilder->M;
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+
+  // Add some known attributes to the tid / bound-tid parameters.
+  OutlinedFn.addParamAttr(0, Attribute::NoAlias);
+  OutlinedFn.addParamAttr(1, Attribute::NoAlias);
+  OutlinedFn.addParamAttr(0, Attribute::NoUndef);
+  OutlinedFn.addParamAttr(1, Attribute::NoUndef);
+  OutlinedFn.addFnAttr(Attribute::NoUnwind);
+
+  assert(OutlinedFn.arg_size() >= 2 &&
+         "Expected at least tid and bounded tid as arguments");
+  unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
+  // dyn_cast (not cast) so the assert below is meaningful: cast<> never
+  // yields null, which made the null check dead code.
+  CallInst *CI = dyn_cast<CallInst>(OutlinedFn.user_back());
+  assert(CI && "Expected call instruction to outlined function");
+  CI->getParent()->setName("omp_parallel");
+
+  // The argument spills and the runtime call are emitted right in front of
+  // the direct call they replace.
+  Builder.SetInsertPoint(CI);
+
+  PointerType *PtrTy = Type::getInt8PtrTy(M.getContext());
+  ArrayType *ArgsArrayTy = ArrayType::get(PtrTy, NumCapturedVars);
+
+  // Allocate the argument array at the first insertion point of OuterFn's
+  // entry block rather than at the call site, then return to the call site.
+  OpenMPIRBuilder::InsertPointTy CurrentIP = Builder.saveIP();
+  Builder.SetInsertPoint(&OuterFn->front(),
+                         OuterFn->front().getFirstInsertionPt());
+  AllocaInst *ArgsAlloca = Builder.CreateAlloca(ArgsArrayTy);
+  Value *Args = Builder.CreatePointerCast(ArgsAlloca, PtrTy);
+  Builder.restoreIP(CurrentIP);
+
+  // Store each captured value into its slot of the array that is handed to
+  // __kmpc_parallel_51. The loop is empty when nothing is captured.
+  for (unsigned Idx = 0; Idx < NumCapturedVars; ++Idx) {
+    Value *V = CI->getArgOperand(2 + Idx);
+    Value *StoreAddress =
+        Builder.CreateConstInBoundsGEP2_64(ArgsArrayTy, Args, 0, Idx);
+    Builder.CreateStore(V, StoreAddress);
+  }
+
+  Type *Int32Ty = OMPIRBuilder->Int32;
+  Value *Cond = IfCondition ? Builder.CreateSExtOrTrunc(IfCondition, Int32Ty)
+                            : Builder.getInt32(1);
+  // The num_threads slot is i32 (its default is getInt32(-1)). Normalize the
+  // width of a caller-supplied value; this is a no-op when it is already i32.
+  Value *NumThreadsArg = NumThreads
+                             ? Builder.CreateZExtOrTrunc(NumThreads, Int32Ty)
+                             : Builder.getInt32(-1);
+
+  // NOTE(review): the null-pointer operand is presumably the runtime's
+  // wrapper-function slot, and getInt64 assumes the trailing count parameter
+  // is 64-bit. Confirm both against the DeviceRTL signature.
+  Value *Parallel51CallArgs[] = {
+      /* identifier */ Ident,
+      /* global thread num */ ThreadID,
+      /* if expression */ Cond,
+      /* number of threads */ NumThreadsArg,
+      /* proc bind */ Builder.getInt32(-1),
+      /* outlined function */
+      Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr),
+      /* null pointer */ ConstantPointerNull::get(PtrTy),
+      /* captured values array */ Args,
+      /* number of captured values */ Builder.getInt64(NumCapturedVars)};
+
+  FunctionCallee RTLFn =
+      OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_parallel_51);
+  // The fixed-size array binds to ArrayRef<Value *> directly; no copy needed.
+  Builder.CreateCall(RTLFn, Parallel51CallArgs);
+
+  LLVM_DEBUG(dbgs() << "With kmpc_parallel_51 placed: "
+                    << *Builder.GetInsertBlock()->getParent() << "\n");
+
+  // Initialize the local TID stack location with the argument value: load an
+  // i32 through OutlinedFn's first argument and store it into PrivTIDAddr.
+  Builder.SetInsertPoint(PrivTID);
+  Function::arg_iterator OutlinedAI = OutlinedFn.arg_begin();
+  Builder.CreateStore(Builder.CreateLoad(OMPIRBuilder->Int32, OutlinedAI),
+                      PrivTIDAddr);
+
+  // The direct call is superseded by the __kmpc_parallel_51 call.
+  CI->eraseFromParent();
+
+  for (Instruction *I : ToBeDeleted)
+    I->eraseFromParent();
+}
+
+// Callback used to create OpenMP runtime calls to support
+// the omp parallel construct on the host.
+// This callback replaces the direct call to OutlinedFn in OuterFn
+// with a call to the OpenMP host runtime function (__kmpc_fork_call[_if]).
+static void
+hostParallelCallback(OpenMPIRBuilder *OMPIRBuilder, Function &OutlinedFn,
+ Function *OuterFn, Value *Ident, Value *IfCondition,
+ Instruction *PrivTID, AllocaInst *PrivTIDAddr,
+ const SmallVector<Instruction *, 4> &ToBeDeleted) {
+ Module &M = OMPIRBuilder->M;
+ IRBuilder<> &Builder = OMPIRBuilder->Builder;
+ FunctionCallee RTLFn;
+ if (IfCondition) {
+ RTLFn =
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call_if);
+ } else {
+ RTLFn =
+ OMPIRBuilder->getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_call);
+ }
+ if (auto *F = dyn_cast<llvm::Function>(RTLFn.getCallee())) {
+ if (!F->hasMetadata(llvm::LLVMContext::MD_callback)) {
+ llvm::LLVMContext &Ctx = F->getContext();
+ MDBuilder MDB(Ctx);
+ // Annotate the callback behavior of the __kmpc_fork_call:
+ // - The callback callee is argument number 2 (microtask).
+ // - The first two arguments of the callback callee are unknown (-1).
+ // - All variadic arguments to the __kmpc_fork_call are passed to the
+ // callback callee.
+ F->addMetadata(
+ llvm::LLVMContext::MD_callback,
+ *llvm::MDNode::get(
+ Ctx, {MDB.createCallbackEncoding(2, {-1, -1},
+ /* VarArgsArePassed */ true)}));
+ }
+ }
+ // Add some known attributes.
+ OutlinedFn.addParamAttr(0, Attribute::NoAlias);
+ OutlinedFn.addParamAttr(1, Attribute::NoAlias);
+ OutlinedFn.addFnAttr(Attribute::NoUnwind);
+
+ assert(OutlinedFn.arg_size() >= 2 &&
+ "Expected at least tid and bounded tid as arguments");
+ unsigned NumCapturedVars = OutlinedFn.arg_size() - /* tid & bounded tid */ 2;
+
+ CallInst *CI = cast<CallInst>(OutlinedFn.user_back());
+ CI->getParent()->setName("omp_parallel");
+ Builder.SetInsertPoint(CI);
+
+ // Build call __kmpc_fork_call[_if](Ident, n, microtask, var1, .., varn);
+ Value *ForkCallArgs[] = {
+ Ident, Builder.getInt32(NumCapturedVars),
+ Builder.CreateBitCast(&OutlinedFn, OMPIRBuilder->ParallelTaskPtr)};
+
+ SmallVector<Value *, 16> RealArgs;
+ RealArgs.append(std::begin(ForkCallArgs), std::end(ForkCallArgs));
+ if (IfCondition) {
+ Value *Cond = Builder.CreateSExtOrTrunc(IfCondition,
+ Type::getInt32Ty(M.getContext()));
+ RealArgs.push_back(Cond);
+ }
+ RealArgs.append(CI->arg_begin() + /* tid & bound tid */ 2, CI->arg_end());
+
+ // __kmpc_fork_call_if always expects a void ptr as the last argument
+ // If there are no arguments, pass a null pointer.
+ auto PtrTy = Type::getInt8PtrTy(M.getContext());
+ if (IfCondition && NumCapturedVars == 0) {
+ Value *Void = ConstantPointerNull::get(PtrTy);
+ RealArgs.push_back(Void);
+ }
+ if (IfCondition && RealArgs.back()->getType() != PtrTy)
+ RealArgs.back() = Builder.CreateBitCast(RealArgs.back(), PtrTy);
----------------
jdoerfert wrote:
This does not look sound, but you just moved it, it's ok.
https://github.com/llvm/llvm-project/pull/67000
More information about the llvm-commits
mailing list