[clang] [llvm] [mlir] [OMPIRBuilder] - Handle dependencies in `createTarget` (PR #93977)

Kareem Ergawy via cfe-commits cfe-commits at lists.llvm.org
Wed Jun 5 02:13:22 PDT 2024


================
@@ -5212,6 +5273,78 @@ static Function *createOutlinedFunction(
   return Func;
 }
 
+// Create an entry point for a target task.
+// It has the following signature:
+// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
+// This function is called from emitTargetTask once the
+// code to launch the target kernel has been outlined already.
+static Function *emitProxyTaskFunction(OpenMPIRBuilder &OMPBuilder,
+                                       IRBuilderBase &Builder,
+                                       CallInst *StaleCI) {
+  Module &M = OMPBuilder.M;
+  // CalledFunction is the target launch function, i.e.
+  // the function that sets up kernel arguments and calls
+  // __tgt_target_kernel to launch the kernel on the device.
+  Function *CalledFunction = StaleCI->getCalledFunction();
+  OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
+                                    StaleCI->getIterator());
+  LLVMContext &Ctx = StaleCI->getParent()->getContext();
+  Type *ThreadIDTy = Type::getInt32Ty(Ctx);
+  Type *TaskPtrTy = OMPBuilder.TaskPtr;
+  Type *TaskTy = OMPBuilder.Task;
+  auto ProxyFnTy =
+      FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
+                        /* isVarArg */ false);
+  auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
+                                  ".omp_target_task_proxy_func",
+                                  Builder.GetInsertBlock()->getModule());
+
+  BasicBlock *EntryBB =
+      BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
+  Builder.SetInsertPoint(EntryBB);
+
+  bool HasShareds = StaleCI->arg_size() > 1;
+  // TODO: This is a temporary assert to prove to ourselves that
+  // the outlined target launch function is always going to have
+  // atmost two arguments if there is any data shared between
+  // host and device.
+  assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
+         "StaleCI with shareds should have exactly two arguments.");
+  if (HasShareds) {
+    AllocaInst *ArgStructAlloca =
+        dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+    assert(ArgStructAlloca &&
+           "Unable to find the alloca instruction corresponding to arguments "
+           "for extracted function");
+    StructType *ArgStructType =
+        dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+    LLVM_DEBUG(dbgs() << "ArgStructType = " << *ArgStructType << "\n");
----------------
ergawy wrote:

nit: This will be printed out of context and might be confusing. I think this can be removed before merging.

https://github.com/llvm/llvm-project/pull/93977


More information about the cfe-commits mailing list