[clang] [llvm] [mlir] [OMPIRBuilder] - Handle dependencies in `createTarget` (PR #93977)
Michael Kruse via cfe-commits
cfe-commits at lists.llvm.org
Fri Jun 7 11:05:41 PDT 2024
================
@@ -5229,13 +5362,288 @@ static void emitTargetOutlinedFunction(
OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
OutlinedFn, OutlinedFnID);
}
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
+ Function *OutlinedFn, Value *OutlinedFnID,
+ EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
+ Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
+ SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
+ bool HasNoWait) {
+
+ // When we arrive at this function, the target region itself has been
+ // outlined into the function OutlinedFn.
+ // So at ths point, for
+ // --------------------------------------------------
+ // void user_code_that_offloads(...) {
+ // omp target depend(..) map(from:a) map(to:b, c)
+ // a = b + c
+ // }
+ //
+ // --------------------------------------------------
+ //
+ // we have
+ //
+ // --------------------------------------------------
+ //
+ // void user_code_that_offloads(...) {
+ // %.offload_baseptrs = alloca [3 x ptr], align 8
+ // %.offload_ptrs = alloca [3 x ptr], align 8
+ // %.offload_mappers = alloca [3 x ptr], align 8
+ // ;; target region has been outlined and now we need to
+ // ;; offload to it via a target task.
+ // }
+ // void outlined_device_function(ptr a, ptr b, ptr c) {
+ // *a = *b + *c
+ // }
+ //
+ // We have to now do the following
+ // (i) Make an offloading call to outlined_device_function using the OpenMP
+ // RTL. See 'kernel_launch_function' in the pseudo code below. This is
+ // emitted by emitKernelLaunch
+ // (ii) Create a task entry point function that calls kernel_launch_function
+ // and is the entry point for the target task. See
+ // '@.omp_target_task_proxy_func in the pseudocode below.
+ // (iii) Create a task with the task entry point created in (ii)
+ //
+ // That is we create the following
+ //
+ // void user_code_that_offloads(...) {
+ // %.offload_baseptrs = alloca [3 x ptr], align 8
+ // %.offload_ptrs = alloca [3 x ptr], align 8
+ // %.offload_mappers = alloca [3 x ptr], align 8
+ //
+ // %structArg = alloca { ptr, ptr, ptr }, align 8
+ // %strucArg[0] = %.offload_baseptrs
+ // %strucArg[1] = %.offload_ptrs
+ // %strucArg[2] = %.offload_mappers
+ // proxy_target_task = @__kmpc_omp_task_alloc(...,
+ // @.omp_target_task_proxy_func)
+ // memcpy(proxy_target_task->shareds, %structArg, sizeof(structArg))
+ // dependencies_array = ...
+ // ;; if nowait not present
+ // call @__kmpc_omp_wait_deps(..., dependencies_array)
+ // call @__kmpc_omp_task_begin_if0(...)
+ // call @ @.omp_target_task_proxy_func(i32 thread_id, ptr
+ // %proxy_target_task) call @__kmpc_omp_task_complete_if0(...)
+ // }
+ //
+ // define internal void @.omp_target_task_proxy_func(i32 %thread.id,
+ // ptr %task) {
+ // %structArg = alloca {ptr, ptr, ptr}
+ // %shared_data = load (getelementptr %task, 0, 0)
+ // mempcy(%structArg, %shared_data, sizeof(structArg))
+ // kernel_launch_function(%thread.id, %structArg)
+ // }
+ //
+ // We need the proxy function because the signature of the task entry point
+ // expected by kmpc_omp_task is always the same and will be different from
+ // that of the kernel_launch function.
+ //
+ // kernel_launch_function is generated by emitKernelLaunch and has the
+ // always_inline attribute. void kernel_launch_function(thread_id,
+ // structArg)
+ // alwaysinline {
+ // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
+ // offload_baseptrs = load(getelementptr structArg, 0, 0)
+ // offload_ptrs = load(getelementptr structArg, 0, 1)
+ // offload_mappers = load(getelementptr structArg, 0, 2)
+ // ; setup kernel_args using offload_baseptrs, offload_ptrs and
+ // ; offload_mappers
+ // call i32 @__tgt_target_kernel(...,
+ // outlined_device_function,
+ // ptr %kernel_args)
+ // }
+ // void outlined_device_function(ptr a, ptr b, ptr c) {
+ // *a = *b + *c
+ // }
+ //
+ BasicBlock *TargetTaskBodyBB =
+ splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
+ BasicBlock *TargetTaskAllocaBB =
+ splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");
+
+ InsertPointTy TargetTaskAllocaIP =
+ InsertPointTy(TargetTaskAllocaBB, TargetTaskAllocaBB->begin());
+ InsertPointTy TargetTaskBodyIP =
+ InsertPointTy(TargetTaskBodyBB, TargetTaskBodyBB->begin());
+
+ OutlineInfo OI;
+ OI.EntryBB = TargetTaskAllocaBB;
+ OI.OuterAllocaBB = AllocaIP.getBlock();
+
+ // Add the thread ID argument.
+ std::stack<Instruction *> ToBeDeleted;
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
+
+ Builder.restoreIP(TargetTaskBodyIP);
+
+ // emitKernelLaunch makes the necessary runtime call to offload the kernel.
+ // We then outline all that code into a separate function
+ // ('kernel_launch_function' in the pseudo code above). This function is then
+ // called by the target task proxy function (see
+ // '@.omp_target_task_proxy_func' in the pseudo code above)
+ // "@.omp_target_task_proxy_func' is generated by emitProxyTaskFunction
+ Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
+ EmitTargetCallFallbackCB, Args, DeviceID,
+ RTLoc, TargetTaskAllocaIP));
+
+ OI.ExitBB = Builder.saveIP().getBlock();
+ OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
+ HasNoWait](Function &OutlinedFn) mutable {
+ assert(OutlinedFn.getNumUses() == 1 &&
+ "there must be a single user for the outlined function");
+
+ CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+ bool HasShareds = StaleCI->arg_size() > 1;
+
+ LLVM_DEBUG(dbgs() << "StaleCI in PostOutlineCB in emitTargetTask = "
+ << *StaleCI << "\n");
+ LLVM_DEBUG(dbgs() << "Module in PostOutlineCB in emitTargetTask = "
+ << *(StaleCI->getParent()->getParent()->getParent())
+ << "\n");
+
+ Function *ProxyFn = emitProxyTaskFunction(*this, Builder, StaleCI);
+
+ LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
+ << "\n");
+
+ Builder.SetInsertPoint(StaleCI);
+
+ // Gather the arguments for emitting the runtime call for
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr =
+ getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
+ Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+ // @__kmpc_omp_task_alloc
+ Function *TaskAllocFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+
+ // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID)
+ // call.
+ Value *ThreadID = getOrCreateThreadID(Ident);
+
+ // Argument - `sizeof_kmp_task_t` (TaskSize)
+ // Tasksize refers to the size in bytes of kmp_task_t data structure
+ // including private vars accessed in task.
+ // TODO: add kmp_task_t_with_privates (privates)
+ Value *TaskSize = Builder.getInt64(
+ divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
+
+ // Argument - `sizeof_shareds` (SharedsSize)
+ // SharedsSize refers to the shareds array size in the kmp_task_t data
+ // structure.
+ Value *SharedsSize = Builder.getInt64(0);
+ if (HasShareds) {
+ AllocaInst *ArgStructAlloca =
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+ assert(ArgStructAlloca &&
+ "Unable to find the alloca instruction corresponding to arguments "
+ "for extracted function");
+ StructType *ArgStructType =
+ dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+ assert(ArgStructType && "Unable to find struct type corresponding to "
+ "arguments for extracted function");
+ SharedsSize =
+ Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+ }
+
+ // Argument - `flags`
+ // Task is tied iff (Flags & 1) == 1.
+ // Task is untied iff (Flags & 1) == 0.
+ // Task is final iff (Flags & 2) == 2.
+ // Task is not final iff (Flags & 2) == 0.
+ // A target task is not final and is untied.
+ Value *Flags = Builder.getInt32(0);
+
+ // Emit the @__kmpc_omp_task_alloc runtime call
+ // The runtime call returns a pointer to an area where the task captured
+ // variables must be copied before the task is run (TaskData)
+ CallInst *TaskData = Builder.CreateCall(
+ TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+ /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
+ /*task_func=*/ProxyFn});
+
+ if (HasShareds) {
+ Value *Shareds = StaleCI->getArgOperand(1);
+ Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
+ Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
+ Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
+ SharedsSize);
+ }
+
+ Value *DepArray = emitDepArray(*this, Dependencies);
+
+ // ---------------------------------------------------------------
+ // V5.2 13.8 target construct
+ // If the nowait clause is present, execution of the target task
+ // may be deferred. If the nowait clause is not present, the target task is
+ // an included task.
+ // ---------------------------------------------------------------
+ // The above means that the lack of a nowait on the target construct
+ // translates to '#pragma omp task if(0)'
+ if (!HasNoWait) {
+ if (DepArray) {
+ Function *TaskWaitFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
+ Builder.CreateCall(
+ TaskWaitFn,
+ {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
+ ConstantInt::get(Builder.getInt32Ty(), 0),
+ ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
+ }
+ // Included task.
+ Function *TaskBeginFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
+ Function *TaskCompleteFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
+ Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+ CallInst *CI = nullptr;
+ if (HasShareds)
+ CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
+ else
+ CI = Builder.CreateCall(ProxyFn, {ThreadID});
+ CI->setDebugLoc(StaleCI->getDebugLoc());
+ Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
+ } else if (DepArray) {
+ // HasNoWait - meaning the task may be deferred. Call
+ // __kmpc_omp_task_with_deps if there are dependencies,
+ // else call __kmpc_omp_task
+ Function *TaskFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
+ Builder.CreateCall(
+ TaskFn,
+ {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
+ DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
+ ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
+ } else {
+ // Emit the @__kmpc_omp_task runtime call to spawn the task
+ Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+ Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
+ }
-static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
- OpenMPIRBuilder::InsertPointTy AllocaIP,
- Function *OutlinedFn, Constant *OutlinedFnID,
- int32_t NumTeams, int32_t NumThreads,
- SmallVectorImpl<Value *> &Args,
- OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {
+ StaleCI->eraseFromParent();
+ while (!ToBeDeleted.empty()) {
+ ToBeDeleted.top()->eraseFromParent();
+ ToBeDeleted.pop();
+ }
----------------
Meinersbur wrote:
```suggestion
llvm::for_each(ToBeDeleted, [](Instruction *I) { I->eraseFromParent(); });
```
https://github.com/llvm/llvm-project/pull/93977
More information about the cfe-commits
mailing list