[llvm] [OMPIRBuilder] Added `createTeams` (PR #66807)
Kiran Chandramohan via llvm-commits
llvm-commits at lists.llvm.org
Tue Sep 19 14:58:40 PDT 2023
================
@@ -5735,6 +5735,140 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
return Builder.saveIP();
}
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
+ BodyGenCallbackTy BodyGenCB) {
+ if (!updateToLocation(Loc))
+ return InsertPointTy();
+
+ uint32_t SrcLocStrSize;
+ Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+ Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+ Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
+
+ // Outer allocation basicblock is the entry block of the current function.
+ BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
+ if (&OuterAllocaBB == Builder.GetInsertBlock()) {
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
+ Builder.SetInsertPoint(BodyBB, BodyBB->begin());
+ }
+
+ // The current basic block is split into four basic blocks. After outlining,
+ // they will be mapped as follows:
+ // ```
+ // def current_fn() {
+ // current_basic_block:
+ // br label %teams.exit
+ // teams.exit:
+ // ; instructions after teams
+ // }
+ //
+ // def outlined_fn() {
+ // teams.alloca:
+ // br label %teams.body
+ // teams.body:
+ // ; instructions within teams body
+ // }
+ // ```
+ BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
+ BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
+ BasicBlock *AllocaBB =
+ splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+
+ OutlineInfo OI;
+ OI.EntryBB = AllocaBB;
+ OI.ExitBB = ExitBB;
+ OI.OuterAllocaBB = &OuterAllocaBB;
+ OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
+ // The input IR here looks like the following-
+ // ```
+ // func @current_fn() {
+ // outlined_fn(%args)
+ // }
+ // func @outlined_fn(%args) { ... }
+ // ```
+ //
+ // This is changed to the following-
+ //
+ // ```
+ // func @current_fn() {
+ // runtime_call(..., wrapper_fn, ...)
+ // }
+ // func @wrapper_fn(..., %args) {
+ // outlined_fn(%args)
+ // }
+ // func @outlined_fn(%args) { ... }
+ // ```
+
+ // The stale call instruction will be replaced with a new call instruction
+ // for runtime call with a wrapper function.
+
+ assert(OutlinedFn.getNumUses() == 1 &&
+ "there must be a single user for the outlined function");
+ CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+ assert(StaleCI && "Error while outlining - no CallInst user found for the "
+ "outlined function.");
+ OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
+ // Create the wrapper function.
+ Builder.SetInsertPoint(StaleCI);
+ SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
+ for (auto &Arg : OutlinedFn.args()) {
+ WrapperArgTys.push_back(Arg.getType());
+ }
+ FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+ (Twine(OutlinedFn.getName()) + ".teams").str(),
+ FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
+ Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
+ WrapperFunc->getArg(0)->setName("global_tid");
+ WrapperFunc->getArg(1)->setName("bound_tid");
+ if (WrapperFunc->arg_size() > 2) {
+ WrapperFunc->getArg(2)->setName("data");
+ }
+
+ // Emit the body of the wrapper function - just a call to outlined function
+ // and return statement.
+ BasicBlock *WrapperEntryBB =
+ BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
+ Builder.SetInsertPoint(WrapperEntryBB);
+ SmallVector<Value *> Args;
+ for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) {
+ Args.push_back(WrapperFunc->getArg(ArgIndex));
+ }
+ Builder.CreateCall(&OutlinedFn, Args);
+ Builder.CreateRetVoid();
+
+ // Call to the runtime function for teams in the current function.
+ Builder.SetInsertPoint(StaleCI);
+ Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
+ for (Use &Arg : StaleCI->args()) {
+ Args.push_back(Arg);
+ }
+ Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
+ omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
+ Args);
+ StaleCI->eraseFromParent();
+
+ // Inlining the outlined teams function in the wrapper. This wrapper is the
+ // argument for the runtime call.
+ assert(OutlinedFn.getNumUses() == 1 &&
+ "More than one use for the outlined function found. Expected only "
+ "one use.");
+ OutlinedFn.addFnAttr(Attribute::AlwaysInline);
----------------
kiranchandramohan wrote:
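This was already done at line 5811 (the earlier `OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline)` call), so this second call is redundant.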
https://github.com/llvm/llvm-project/pull/66807
More information about the llvm-commits mailing list