[Mlir-commits] [clang] [llvm] [mlir] [OMPIRBuilder] - Handle dependencies in `createTarget` (PR #93977)

Michael Kruse llvmlistbot at llvm.org
Fri Jun 7 11:05:41 PDT 2024


================
@@ -5212,6 +5273,78 @@ static Function *createOutlinedFunction(
   return Func;
 }
 
+// Create an entry point for a target task with the following.
+// It'll have the following signature
+// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
+// This function is called from emitTargetTask once the
+// code to launch the target kernel has been outlined already.
+static Function *emitProxyTaskFunction(OpenMPIRBuilder &OMPBuilder,
+                                       IRBuilderBase &Builder,
+                                       CallInst *StaleCI) {
+  Module &M = OMPBuilder.M;
+  // CalledFunction is the target launch function, i.e.
+  // the function that sets up kernel arguments and calls
+  // __tgt_target_kernel to launch the kernel on the device.
+  Function *CalledFunction = StaleCI->getCalledFunction();
+  OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
+                                    StaleCI->getIterator());
+  LLVMContext &Ctx = StaleCI->getParent()->getContext();
+  Type *ThreadIDTy = Type::getInt32Ty(Ctx);
+  Type *TaskPtrTy = OMPBuilder.TaskPtr;
+  Type *TaskTy = OMPBuilder.Task;
+  auto ProxyFnTy =
+      FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
+                        /* isVarArg */ false);
+  auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
+                                  ".omp_target_task_proxy_func",
+                                  Builder.GetInsertBlock()->getModule());
+
+  BasicBlock *EntryBB =
+      BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
+  Builder.SetInsertPoint(EntryBB);
+
+  bool HasShareds = StaleCI->arg_size() > 1;
+  // TODO: This is a temporary assert to prove to ourselves that
+  // the outlined target launch function is always going to have
+  // atmost two arguments if there is any data shared between
+  // host and device.
+  assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
+         "StaleCI with shareds should have exactly two arguments.");
+  if (HasShareds) {
+    AllocaInst *ArgStructAlloca =
+        dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+    assert(ArgStructAlloca &&
+           "Unable to find the alloca instruction corresponding to arguments "
+           "for extracted function");
+    StructType *ArgStructType =
+        dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+    LLVM_DEBUG(dbgs() << "ArgStructType = " << *ArgStructType << "\n");
+
+    AllocaInst *NewArgStructAlloca =
+        Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
+    Value *TaskT = ProxyFn->getArg(1);
+    Value *ThreadId = ProxyFn->getArg(0);
+    LLVM_DEBUG(dbgs() << "TaskT = " << *TaskT << "\n");
+    Value *SharedsSize =
+        Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+
+    Value *Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
+    LoadInst *LoadShared =
+        Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);
+
+    // TODO: Are these alignment values correct?
----------------
Meinersbur wrote:

I think `NewArgStructAlloca->getAlign()` should be sufficient. If the alloca/load doesn't have the alignment set explicitly, `memcpy` should apply pointer alignment itself.

https://github.com/llvm/llvm-project/pull/93977


More information about the Mlir-commits mailing list