[flang-commits] [flang] [OpenMPIRBuilder] Remove wrapper function in `createTask`, `createTeams` (PR #67723)
Johannes Doerfert via flang-commits
flang-commits at lists.llvm.org
Tue Oct 24 12:49:29 PDT 2023
================
@@ -5771,84 +5779,63 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
BasicBlock *AllocaBB =
splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+ // Generate the body of teams.
+ InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+ InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+ BodyGenCB(AllocaIP, CodeGenIP);
+
OutlineInfo OI;
OI.EntryBB = AllocaBB;
OI.ExitBB = ExitBB;
OI.OuterAllocaBB = &OuterAllocaBB;
- OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
- // The input IR here looks like the following-
- // ```
- // func @current_fn() {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
- //
- // This is changed to the following-
- //
- // ```
- // func @current_fn() {
- // runtime_call(..., wrapper_fn, ...)
- // }
- // func @wrapper_fn(..., %args) {
- // outlined_fn(%args)
- // }
- // func @outlined_fn(%args) { ... }
- // ```
+ // Insert fake values for global tid and bound tid.
+ std::stack<Instruction *> ToBeDeleted;
+ InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
+ OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+ Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
+
+ OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
// The stale call instruction will be replaced with a new call instruction
- // for runtime call with a wrapper function.
+ // for runtime call with the outlined function.
assert(OutlinedFn.getNumUses() == 1 &&
"there must be a single user for the outlined function");
CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+ ToBeDeleted.push(StaleCI);
+
+ assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
+ "Outlined function must have two or three arguments only");
+
+ bool HasShared = OutlinedFn.arg_size() == 3;
- // Create the wrapper function.
- SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
- for (auto &Arg : OutlinedFn.args())
- WrapperArgTys.push_back(Arg.getType());
- FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
- (Twine(OutlinedFn.getName()) + ".teams").str(),
- FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
- Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
- WrapperFunc->getArg(0)->setName("global_tid");
- WrapperFunc->getArg(1)->setName("bound_tid");
- if (WrapperFunc->arg_size() > 2)
- WrapperFunc->getArg(2)->setName("data");
-
- // Emit the body of the wrapper function - just a call to outlined function
- // and return statement.
- BasicBlock *WrapperEntryBB =
- BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
- Builder.SetInsertPoint(WrapperEntryBB);
- SmallVector<Value *> Args;
- for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
- Args.push_back(WrapperFunc->getArg(ArgIndex));
- Builder.CreateCall(&OutlinedFn, Args);
- Builder.CreateRetVoid();
-
- OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+ OutlinedFn.getArg(0)->setName("global.tid.ptr");
+ OutlinedFn.getArg(1)->setName("bound.tid.ptr");
+ if (HasShared)
+ OutlinedFn.getArg(2)->setName("data");
// Call to the runtime function for teams in the current function.
assert(StaleCI && "Error while outlining - no CallInst user found for the "
"outlined function.");
Builder.SetInsertPoint(StaleCI);
- Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
- for (Use &Arg : StaleCI->args())
- Args.push_back(Arg);
+ SmallVector<Value *> Args = {
+ Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
+ if (HasShared)
+ Args.push_back(StaleCI->getArgOperand(2));
Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
Args);
- StaleCI->eraseFromParent();
+
+ while (!ToBeDeleted.empty()) {
+ ToBeDeleted.top()->eraseFromParent();
+ ToBeDeleted.pop();
----------------
jdoerfert wrote:
same
https://github.com/llvm/llvm-project/pull/67723
More information about the flang-commits mailing list