[llvm] [mlir] [Flang][OpenMP][Taskloop] Translation support for taskloop construct (PR #166903)
Kaviya Rajendiran via llvm-commits
llvm-commits at lists.llvm.org
Sun Nov 16 10:48:58 PST 2025
================
@@ -1933,6 +1933,205 @@ static Value *emitTaskDependencies(
return DepArray;
}
+OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTaskloop(
+ const LocationDescription &Loc, InsertPointTy AllocaIP,
+ BodyGenCallbackTy BodyGenCB,
+ llvm::function_ref<llvm::Expected<llvm::CanonicalLoopInfo *>()> loopInfo,
+ Value *LBVal, Value *UBVal, Value *StepVal, bool Tied) {
+
+ if (!updateToLocation(Loc))
+ return InsertPointTy();
+
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+ Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+ BasicBlock *TaskloopExitBB =
+ splitBB(Builder, /*CreateBranch=*/true, "taskloop.exit");
+ BasicBlock *TaskloopBodyBB =
+ splitBB(Builder, /*CreateBranch=*/true, "taskloop.body");
+ BasicBlock *TaskloopAllocaBB =
+ splitBB(Builder, /*CreateBranch=*/true, "taskloop.alloca");
+
+ InsertPointTy TaskloopAllocaIP =
+ InsertPointTy(TaskloopAllocaBB, TaskloopAllocaBB->begin());
+ InsertPointTy TaskloopBodyIP =
+ InsertPointTy(TaskloopBodyBB, TaskloopBodyBB->begin());
+
+ if (Error Err = BodyGenCB(TaskloopAllocaIP, TaskloopBodyIP))
+ return Err;
+
+ llvm::Expected<llvm::CanonicalLoopInfo *> result = loopInfo();
+ if (!result) {
+ return result.takeError();
+ }
+
+ llvm::CanonicalLoopInfo *CLI = result.get();
+ OutlineInfo OI;
+ OI.EntryBB = TaskloopAllocaBB;
+ OI.OuterAllocaBB = AllocaIP.getBlock();
+ OI.ExitBB = TaskloopExitBB;
+
+ // Add the thread ID argument.
+ SmallVector<Instruction *, 4> ToBeDeleted;
+ // dummy instruction to be used as a fake argument
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, AllocaIP, ToBeDeleted, TaskloopAllocaIP, "global.tid", false));
+
+ OI.PostOutlineCB = [this, Ident, LBVal, UBVal, StepVal, Tied,
+ TaskloopAllocaBB, CLI, Loc,
+ ToBeDeleted](Function &OutlinedFn) mutable {
+ // Replace the Stale CI by appropriate RTL function call.
+ assert(OutlinedFn.hasOneUse() &&
+ "there must be a single user for the outlined function");
+ CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+
+ // HasShareds is true if any variables are captured in the outlined region,
+ // false otherwise.
+ bool HasShareds = StaleCI->arg_size() > 1;
+ Builder.SetInsertPoint(StaleCI);
+
+ // Gather the arguments for emitting the runtime call for
+ // @__kmpc_omp_task_alloc
+ Function *TaskAllocFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_omp_task_alloc);
+
+ Value *ThreadID = getOrCreateThreadID(Ident);
+
+ // Emit runtime call for @__kmpc_taskgroup
+ Function *TaskgroupFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskgroup);
+ Builder.CreateCall(TaskgroupFn, {Ident, ThreadID});
+
+ // The flags are set to 1 if the task is tied, 0 otherwise.
+ Value *Flags = Builder.getInt32(Tied);
+
+ Value *TaskSize = Builder.getInt64(
+ divideCeil(M.getDataLayout().getTypeSizeInBits(Taskloop), 8));
+
+ Value *SharedsSize = Builder.getInt64(0);
+ if (HasShareds) {
+ AllocaInst *ArgStructAlloca =
+ dyn_cast<AllocaInst>(StaleCI->getArgOperand(1));
+ assert(ArgStructAlloca &&
+ "Unable to find the alloca instruction corresponding to arguments "
+ "for extracted function");
+ StructType *ArgStructType =
+ dyn_cast<StructType>(ArgStructAlloca->getAllocatedType());
+ assert(ArgStructType && "Unable to find struct type corresponding to "
+ "arguments for extracted function");
+ SharedsSize =
+ Builder.getInt64(M.getDataLayout().getTypeStoreSize(ArgStructType));
+ }
+
+ // Emit the @__kmpc_omp_task_alloc runtime call
+ // The runtime call returns a pointer to an area where the task captured
+ // variables must be copied before the task is run (TaskData)
+ CallInst *TaskData = Builder.CreateCall(
+ TaskAllocFn, {/*loc_ref=*/Ident, /*gtid=*/ThreadID, /*flags=*/Flags,
+ /*sizeof_task=*/TaskSize, /*sizeof_shared=*/SharedsSize,
+ /*task_func=*/&OutlinedFn});
+
+ // Get the pointer to loop lb, ub, step from task ptr
+ // and set up the lowerbound,upperbound and step values
+ llvm::Value *lb =
+ Builder.CreateStructGEP(OpenMPIRBuilder::Taskloop, TaskData, 5);
+ // Value *LbVal_ext = Builder.CreateSExt(LBVal, Builder.getInt64Ty());
+ Builder.CreateStore(LBVal, lb);
+
+ llvm::Value *ub =
+ Builder.CreateStructGEP(OpenMPIRBuilder::Taskloop, TaskData, 6);
+ Builder.CreateStore(UBVal, ub);
+
+ llvm::Value *step =
+ Builder.CreateStructGEP(OpenMPIRBuilder::Taskloop, TaskData, 7);
+ Value *Step_ext = Builder.CreateSExt(StepVal, Builder.getInt64Ty());
+ Builder.CreateStore(Step_ext, step);
+ llvm::Value *loadstep = Builder.CreateLoad(Builder.getInt64Ty(), step);
+
+ if (HasShareds) {
+ Value *Shareds = StaleCI->getArgOperand(1);
+ Align Alignment = TaskData->getPointerAlignment(M.getDataLayout());
+ Value *TaskShareds = Builder.CreateLoad(VoidPtr, TaskData);
+ Builder.CreateMemCpy(TaskShareds, Alignment, Shareds, Alignment,
+ SharedsSize);
+ }
+
+ // set up the arguments for emitting kmpc_taskloop runtime call
+ // setting default values for ifval, nogroup, sched, grainsize, task_dup
+ Value *IfVal = Builder.getInt32(1);
+ Value *NoGroup = Builder.getInt32(1);
+ Value *Sched = Builder.getInt32(0);
+ Value *GrainSize = Builder.getInt64(0);
+ Value *TaskDup = Constant::getNullValue(Builder.getPtrTy());
+
+ Value *Args[] = {Ident, ThreadID, TaskData, IfVal, lb, ub,
+ loadstep, NoGroup, Sched, GrainSize, TaskDup};
+
+ // taskloop runtime call
+ Function *TaskloopFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_taskloop);
+ Builder.CreateCall(TaskloopFn, Args);
+
+ // Emit the @__kmpc_end_taskgroup runtime call to end the taskgroup
+ Function *EndTaskgroupFn =
+ getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_end_taskgroup);
+ Builder.CreateCall(EndTaskgroupFn, {Ident, ThreadID});
+
+ StaleCI->eraseFromParent();
+
+ Builder.SetInsertPoint(TaskloopAllocaBB, TaskloopAllocaBB->begin());
+
+ if (HasShareds) {
+ LoadInst *Shareds = Builder.CreateLoad(VoidPtr, OutlinedFn.getArg(1));
+ OutlinedFn.getArg(1)->replaceUsesWithIf(
+ Shareds, [Shareds](Use &U) { return U.getUser() != Shareds; });
+ }
+
+ Value *IV = CLI->getIndVar();
+ Type *IVTy = IV->getType();
+ Constant *One = ConstantInt::get(IVTy, 1);
+
+ Value *task_lb = Builder.CreateStructGEP(OpenMPIRBuilder::Taskloop,
+ OutlinedFn.getArg(1), 5, "gep_lb");
+ Value *LowerBound = Builder.CreateLoad(IVTy, task_lb, "lb");
+
+ Value *task_ub = Builder.CreateStructGEP(OpenMPIRBuilder::Taskloop,
+ OutlinedFn.getArg(1), 6, "gep_ub");
+ Value *UpperBound = Builder.CreateLoad(IVTy, task_ub, "ub");
+
+ Builder.SetInsertPoint(CLI->getPreheader()->getTerminator());
+
+ Value *TripCountMinusOne = Builder.CreateSub(UpperBound, LowerBound);
+ Value *TripCount = Builder.CreateAdd(TripCountMinusOne, One, "trip_cnt");
+ // set the trip count in the CLI
+ CLI->setTripCount(TripCount);
+
+ Builder.SetInsertPoint(CLI->getBody(),
+ CLI->getBody()->getFirstInsertionPt());
+
+ llvm::BasicBlock *Body = CLI->getBody();
+ for (llvm::Instruction &I : *Body) {
+ if (auto *Add = llvm::dyn_cast<llvm::BinaryOperator>(&I)) {
+ if (Add->getOpcode() == llvm::Instruction::Add) {
+ if (llvm::isa<llvm::BinaryOperator>(Add->getOperand(0))) {
+ // update the starting index of the loop
+ Add->setOperand(1, LowerBound);
+ }
+ }
+ }
----------------
kaviya2510 wrote:
Yes, I agree with your comments that it might match other add instructions. The reason for this change is that the taskloop construct divides the loop iterations into chunks, and each chunk is executed as an explicit task. The loop bounds (lower bound, upper bound and step) for these chunks are computed by the runtime function `__kmpc_taskloop(...)`, so we need to update the loop nest with the bounds returned by the runtime.
The current loop-nest translation sets the global loop bounds. This change ensures that the loop bounds are adjusted according to the values returned by the runtime.
I explored several alternative approaches to updating the loop bounds based on the runtime values, but none of them worked. Also, the loop-nest translation is done at this stage, where it returns the runtime loop-bound values.
Could you share your thoughts if you have a better suggestion for handling this scenario?
https://github.com/llvm/llvm-project/pull/166903
More information about the llvm-commits mailing list