[clang] [llvm] [mlir] [OMPIRBuilder] - Handle dependencies in `createTarget` (PR #93977)

Michael Kruse via cfe-commits cfe-commits at lists.llvm.org
Fri Jun 7 11:05:41 PDT 2024


================
@@ -5229,13 +5362,288 @@ static void emitTargetOutlinedFunction(
   OMPBuilder.emitTargetRegionFunction(EntryInfo, GenerateOutlinedFunction, true,
                                       OutlinedFn, OutlinedFnID);
 }
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitTargetTask(
+    Function *OutlinedFn, Value *OutlinedFnID,
+    EmitFallbackCallbackTy EmitTargetCallFallbackCB, TargetKernelArgs &Args,
+    Value *DeviceID, Value *RTLoc, OpenMPIRBuilder::InsertPointTy AllocaIP,
+    SmallVector<llvm::OpenMPIRBuilder::DependData> &Dependencies,
+    bool HasNoWait) {
+
+  // When we arrive at this function, the target region itself has been
+  // outlined into the function OutlinedFn.
+  // So at this point, for
+  // --------------------------------------------------
+  //   void user_code_that_offloads(...) {
+  //     omp target depend(..) map(from:a) map(to:b, c)
+  //        a = b + c
+  //   }
+  //
+  // --------------------------------------------------
+  //
+  // we have
+  //
+  // --------------------------------------------------
+  //
+  //   void user_code_that_offloads(...) {
+  //     %.offload_baseptrs = alloca [3 x ptr], align 8
+  //     %.offload_ptrs = alloca [3 x ptr], align 8
+  //     %.offload_mappers = alloca [3 x ptr], align 8
+  //     ;; target region has been outlined and now we need to
+  //     ;; offload to it via a target task.
+  //   }
+  //   void outlined_device_function(ptr a, ptr b, ptr c) {
+  //     *a = *b + *c
+  //   }
+  //
+  // We have to now do the following
+  // (i)   Make an offloading call to outlined_device_function using the OpenMP
+  //       RTL. See 'kernel_launch_function' in the pseudo code below. This is
+  //       emitted by emitKernelLaunch
+  // (ii)  Create a task entry point function that calls kernel_launch_function
+  //       and is the entry point for the target task. See
+  //       '@.omp_target_task_proxy_func' in the pseudocode below.
+  // (iii) Create a task with the task entry point created in (ii)
+  //
+  // That is we create the following
+  //
+  //   void user_code_that_offloads(...) {
+  //     %.offload_baseptrs = alloca [3 x ptr], align 8
+  //     %.offload_ptrs = alloca [3 x ptr], align 8
+  //     %.offload_mappers = alloca [3 x ptr], align 8
+  //
+  //     %structArg = alloca { ptr, ptr, ptr }, align 8
+  //     %structArg[0] = %.offload_baseptrs
+  //     %structArg[1] = %.offload_ptrs
+  //     %structArg[2] = %.offload_mappers
+  //     proxy_target_task = @__kmpc_omp_task_alloc(...,
+  //                                               @.omp_target_task_proxy_func)
+  //     memcpy(proxy_target_task->shareds, %structArg, sizeof(structArg))
+  //     dependencies_array = ...
+  //     ;; if nowait not present
+  //     call @__kmpc_omp_wait_deps(..., dependencies_array)
+  //     call @__kmpc_omp_task_begin_if0(...)
+  //     call @.omp_target_task_proxy_func(i32 thread_id, ptr %proxy_target_task)
+  //     call @__kmpc_omp_task_complete_if0(...)
+  //   }
+  //
+  //   define internal void @.omp_target_task_proxy_func(i32 %thread.id,
+  //                                                     ptr %task) {
+  //       %structArg = alloca {ptr, ptr, ptr}
+  //       %shared_data = load (getelementptr %task, 0, 0)
+  //       memcpy(%structArg, %shared_data, sizeof(structArg))
+  //       kernel_launch_function(%thread.id, %structArg)
+  //   }
+  //
+  //   We need the proxy function because the signature of the task entry point
+  //   expected by kmpc_omp_task is always the same and will be different from
+  //   that of the kernel_launch function.
+  //
+  //   kernel_launch_function is generated by emitKernelLaunch and has the
+  //   always_inline attribute.
+  //   void kernel_launch_function(thread_id,
+  //                               structArg) alwaysinline {
+  //       %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
+  //       offload_baseptrs = load(getelementptr structArg, 0, 0)
+  //       offload_ptrs = load(getelementptr structArg, 0, 1)
+  //       offload_mappers = load(getelementptr structArg, 0, 2)
+  //       ; setup kernel_args using offload_baseptrs, offload_ptrs and
+  //       ; offload_mappers
+  //       call i32 @__tgt_target_kernel(...,
+  //                                     outlined_device_function,
+  //                                     ptr %kernel_args)
+  //   }
+  //   void outlined_device_function(ptr a, ptr b, ptr c) {
+  //      *a = *b + *c
+  //   }
+  //
+  BasicBlock *TargetTaskBodyBB =
+      splitBB(Builder, /*CreateBranch=*/true, "target.task.body");
+  BasicBlock *TargetTaskAllocaBB =
+      splitBB(Builder, /*CreateBranch=*/true, "target.task.alloca");
+
+  InsertPointTy TargetTaskAllocaIP =
+      InsertPointTy(TargetTaskAllocaBB, TargetTaskAllocaBB->begin());
+  InsertPointTy TargetTaskBodyIP =
+      InsertPointTy(TargetTaskBodyBB, TargetTaskBodyBB->begin());
+
+  OutlineInfo OI;
+  OI.EntryBB = TargetTaskAllocaBB;
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+
+  // Add the thread ID argument.
+  std::stack<Instruction *> ToBeDeleted;
+  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+      Builder, AllocaIP, ToBeDeleted, TargetTaskAllocaIP, "global.tid", false));
+
+  Builder.restoreIP(TargetTaskBodyIP);
+
+  // emitKernelLaunch makes the necessary runtime call to offload the kernel.
+  // We then outline all that code into a separate function
+  // ('kernel_launch_function' in the pseudo code above). This function is then
+  // called by the target task proxy function (see
+  // '@.omp_target_task_proxy_func' in the pseudo code above)
+  // '@.omp_target_task_proxy_func' is generated by emitProxyTaskFunction
+  Builder.restoreIP(emitKernelLaunch(Builder, OutlinedFn, OutlinedFnID,
+                                     EmitTargetCallFallbackCB, Args, DeviceID,
+                                     RTLoc, TargetTaskAllocaIP));
+
+  OI.ExitBB = Builder.saveIP().getBlock();
+  OI.PostOutlineCB = [this, ToBeDeleted, Dependencies,
+                      HasNoWait](Function &OutlinedFn) mutable {
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    bool HasShareds = StaleCI->arg_size() > 1;
+
+    LLVM_DEBUG(dbgs() << "StaleCI in PostOutlineCB in emitTargetTask = "
+                      << *StaleCI << "\n");
+    LLVM_DEBUG(dbgs() << "Module in PostOutlineCB in emitTargetTask = "
+                      << *(StaleCI->getParent()->getParent()->getParent())
+                      << "\n");
+
+    Function *ProxyFn = emitProxyTaskFunction(*this, Builder, StaleCI);
+
+    LLVM_DEBUG(dbgs() << "Proxy task entry function created: " << *ProxyFn
+                      << "\n");
+
+    Builder.SetInsertPoint(StaleCI);
+
+    // Gather the arguments for the @__kmpc_omp_task_alloc runtime call.
+    uint32_t SrcLocStrSize;
+    Constant *SrcLocStr =
+        getOrCreateSrcLocStr(LocationDescription(Builder), SrcLocStrSize);
+    Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+    // @__kmpc_omp_task_alloc
+    Function *TaskAllocFn =
+        getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+
+    // Arguments - `loc_ref` (Ident) and `gtid` (ThreadID) of the runtime
+    // call.
+    Value *ThreadID = getOrCreateThreadID(Ident);
+
+    // Argument - `sizeof_kmp_task_t` (TaskSize)
+    // Tasksize refers to the size in bytes of kmp_task_t data structure
+    // including private vars accessed in task.
+    // TODO: add kmp_task_t_with_privates (privates)
+    Value *TaskSize = Builder.getInt64(
+        divideCeil(M.getDataLayout().getTypeSizeInBits(Task), 8));
+
+    // Argument - `sizeof_shareds` (SharedsSize)
+    // SharedsSize refers to the shareds array size in the kmp_task_t data
+    // structure.
+    Value *SharedsSize = Builder.getInt64(0);
+    if (HasShareds) {
+      AllocaInst *ArgStructAlloca =
+          dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+      assert(ArgStructAlloca &&
+             "Unable to find the alloca instruction corresponding to arguments "
+             "for extracted function");
+      StructType *ArgStructType =
+          dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+      assert(ArgStructType && "Unable to find struct type corresponding to "
+                              "arguments for extracted function");
+      SharedsSize =
+          Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+    }
+
+    // Argument - `flags`
+    // Task is tied iff (Flags & 1) == 1.
+    // Task is untied iff (Flags & 1) == 0.
+    // Task is final iff (Flags & 2) == 2.
+    // Task is not final iff (Flags & 2) == 0.
+    // A target task is not final and is untied.
+    Value *Flags = Builder.getInt32(0);
+
+    // Emit the @__kmpc_omp_task_alloc runtime call
+    // The runtime call returns a pointer to an area where the task captured
+    // variables must be copied before the task is run (TaskData)
+    CallInst *TaskData = Builder.CreateCall(
+        TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+                      /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
+                      /*task_func=*/ProxyFn});
+
+    if (HasShareds) {
+      Value *Shareds = StaleCI->getArgOperand(1);
+      Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
+      Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
+      Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
+                           SharedsSize);
+    }
+
+    Value *DepArray = emitDepArray(*this, Dependencies);
+
+    // ---------------------------------------------------------------
+    // V5.2 13.8 target construct
+    // If the nowait clause is present, execution of the target task
+    // may be deferred. If the nowait clause is not present, the target task is
+    // an included task.
+    // ---------------------------------------------------------------
+    // The above means that the lack of a nowait on the target construct
+    // translates to '#pragma omp task if(0)'
+    if (!HasNoWait) {
+      if (DepArray) {
+        Function *TaskWaitFn =
+            getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_wait_deps);
+        Builder.CreateCall(
+            TaskWaitFn,
+            {Ident, ThreadID, Builder.getInt32(Dependencies.size()), DepArray,
+             ConstantInt::get(Builder.getInt32Ty(), 0),
+             ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
+      }
+      // Included task.
+      Function *TaskBeginFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_begin_if0);
+      Function *TaskCompleteFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_complete_if0);
+      Builder.CreateCall(TaskBeginFn, {Ident, ThreadID, TaskData});
+      CallInst *CI = nullptr;
+      if (HasShareds)
+        CI = Builder.CreateCall(ProxyFn, {ThreadID, TaskData});
+      else
+        CI = Builder.CreateCall(ProxyFn, {ThreadID});
+      CI->setDebugLoc(StaleCI->getDebugLoc());
+      Builder.CreateCall(TaskCompleteFn, {Ident, ThreadID, TaskData});
+    } else if (DepArray) {
+      // HasNoWait - meaning the task may be deferred. Call
+      // __kmpc_omp_task_with_deps if there are dependencies,
+      // else call __kmpc_omp_task
+      Function *TaskFn =
+          getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_with_deps);
+      Builder.CreateCall(
+          TaskFn,
+          {Ident, ThreadID, TaskData, Builder.getInt32(Dependencies.size()),
+           DepArray, ConstantInt::get(Builder.getInt32Ty(), 0),
+           ConstantPointerNull::get(PointerType::getUnqual(M.getContext()))});
+    } else {
+      // Emit the @__kmpc_omp_task runtime call to spawn the task
+      Function *TaskFn = getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task);
+      Builder.CreateCall(TaskFn, {Ident, ThreadID, TaskData});
+    }
 
-static void emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-                           OpenMPIRBuilder::InsertPointTy AllocaIP,
-                           Function *OutlinedFn, Constant *OutlinedFnID,
-                           int32_t NumTeams, int32_t NumThreads,
-                           SmallVectorImpl<Value *> &Args,
-                           OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB) {
+    StaleCI->eraseFromParent();
+    while (!ToBeDeleted.empty()) {
+      ToBeDeleted.top()->eraseFromParent();
+      ToBeDeleted.pop();
+    }
----------------
Meinersbur wrote:

```suggestion
    llvm::for_each(ToBeDeleted, [](Instruction *I) { I->eraseFromParent(); } );
```

https://github.com/llvm/llvm-project/pull/93977


More information about the cfe-commits mailing list