[clang] [llvm] [mlir] [OMPIRBuilder] - Handle dependencies in `createTarget` (PR #93977)
Pranav Bhandarkar via cfe-commits
cfe-commits at lists.llvm.org
Wed Jun 5 14:39:24 PDT 2024
================
@@ -5212,6 +5273,78 @@ static Function *createOutlinedFunction(
return Func;
}
+// Create an entry point for a target task.
+// It has the following signature:
+// void @.omp_target_task_proxy_func(i32 %thread.id, ptr %task)
+// This function is called from emitTargetTask once the
+// code to launch the target kernel has already been outlined.
+static Function *emitProxyTaskFunction(OpenMPIRBuilder &OMPBuilder,
+ IRBuilderBase &Builder,
+ CallInst *StaleCI) {
+ Module &M = OMPBuilder.M;
+ // CalledFunction is the target launch function, i.e.
+ // the function that sets up kernel arguments and calls
+ // __tgt_target_kernel to launch the kernel on the device.
+ Function *CalledFunction = StaleCI->getCalledFunction();
+ OpenMPIRBuilder::InsertPointTy IP(StaleCI->getParent(),
+ StaleCI->getIterator());
+ LLVMContext &Ctx = StaleCI->getParent()->getContext();
+ Type *ThreadIDTy = Type::getInt32Ty(Ctx);
+ Type *TaskPtrTy = OMPBuilder.TaskPtr;
+ Type *TaskTy = OMPBuilder.Task;
+ auto ProxyFnTy =
+ FunctionType::get(Builder.getVoidTy(), {ThreadIDTy, TaskPtrTy},
+ /* isVarArg */ false);
+ auto ProxyFn = Function::Create(ProxyFnTy, GlobalValue::InternalLinkage,
+ ".omp_target_task_proxy_func",
+ Builder.GetInsertBlock()->getModule());
+
+ BasicBlock *EntryBB =
+ BasicBlock::Create(Builder.getContext(), "entry", ProxyFn);
+ Builder.SetInsertPoint(EntryBB);
+
+ bool HasShareds = StaleCI->arg_size() > 1;
+ // TODO: This is a temporary assert to prove to ourselves that
+ // the outlined target launch function is always going to have
+ // at most two arguments if there is any data shared between
+ // host and device.
+ assert((!HasShareds || (StaleCI->arg_size() == 2)) &&
+ "StaleCI with shareds should have exactly two arguments.");
+ if (HasShareds) {
+ AllocaInst *ArgStructAlloca =
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+ assert(ArgStructAlloca &&
+ "Unable to find the alloca instruction corresponding to arguments "
+ "for extracted function");
+ StructType *ArgStructType =
+ dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+ LLVM_DEBUG(dbgs() << "ArgStructType = " << *ArgStructType << "\n");
+
+ AllocaInst *NewArgStructAlloca =
+ Builder.CreateAlloca(ArgStructType, nullptr, "structArg");
+ Value *TaskT = ProxyFn->getArg(1);
+ Value *ThreadId = ProxyFn->getArg(0);
+ LLVM_DEBUG(dbgs() << "TaskT = " << *TaskT << "\n");
+ Value *SharedsSize =
+ Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+
+ Value *Shareds = Builder.CreateStructGEP(TaskTy, TaskT, 0);
+ LoadInst *LoadShared =
+ Builder.CreateLoad(PointerType::getUnqual(Ctx), Shareds);
+
+ // TODO: Are these alignment values correct?
+ Builder.CreateMemCpy(
+ NewArgStructAlloca,
+ NewArgStructAlloca->getPointerAlignment(M.getDataLayout()), LoadShared,
+ LoadShared->getPointerAlignment(M.getDataLayout()), SharedsSize);
+
+ Builder.CreateCall(CalledFunction, {ThreadId, NewArgStructAlloca});
+ }
+ ProxyFn->getArg(0)->setName("thread.id");
+ ProxyFn->getArg(1)->setName("task");
----------------
bhandarkar-pranav wrote:
Sure, will do.
https://github.com/llvm/llvm-project/pull/93977
More information about the cfe-commits
mailing list