[clang] [llvm] [mlir] [OMPIRBuilder] - Handle dependencies in `createTarget` (PR #93977)
Kareem Ergawy via llvm-commits
llvm-commits at lists.llvm.org
Wed Jun 5 02:13:22 PDT 2024
================
@@ -5212,6 +5273,78 @@ static Function *createOutlinedFunction(
return Func;
}
+// Create an entry point for a target task.
+// It will have the following signature:
+// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
+// This function is called from emitTargetTask once the
+// code to launch the target kernel has been outlined already.
+static Function *emitProxyTaskFunction(OpenMPIRBuilder &OMPBuilder,
+ IRBuilderBase &Builder,
+ CallInst *StaleCI) {
+ Module &M = OMPBuilder.M;
+ // CalledFunction is the target launch function, i.e.
+ // the function that sets up kernel arguments and calls
+ // __tgt_target_kernel to launch the kernel on the device.
+ Function *CalledFunction = StaleCI->getCalledFunction();
+ OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
+ StaleCI->getIterator());
+ LLVMContext &Ctx = StaleCI->getParent()->getContext();
+ Type *ThreadIDTy = Type::getInt32Ty(Ctx);
+ Type *TaskPtrTy = OMPBuilder.TaskPtr;
+ Type *TaskTy = OMPBuilder.Task;
+ auto ProxyFnTy =
+ FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
+ /* isVarArg */ false);
+ auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
+ ".omp_target_task_proxy_func",
+ Builder.GetInsertBlock()->getModule());
+
+ BasicBlock *EntryBB =
+ BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
+ Builder.SetInsertPoint(EntryBB);
+
+ bool HasShareds = StaleCI->arg_size() > 1;
+ // TODO: This is a temporary assert to prove to ourselves that
+ // the outlined target launch function is always going to have
+ // at most two arguments if there is any data shared between
+ // host and device.
+ assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
+ "StaleCI with shareds should have exactly two arguments.");
+ if (HasShareds) {
+ AllocaInst *ArgStructAlloca =
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+ assert(ArgStructAlloca &&
+ "Unable to find the alloca instruction corresponding to arguments "
+ "for extracted function");
+ StructType *ArgStructType =
+ dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+ LLVM_DEBUG(dbgs() << "ArgStructType = " << *ArgStructType << "\n");
----------------
ergawy wrote:
nit: This will be printed out-of-context and might be confusing. I think this can be removed before merging.
https://github.com/llvm/llvm-project/pull/93977
More information about the llvm-commits
mailing list