[flang-commits] [flang] [OpenMPIRBuilder] Remove wrapper function in `createTask`, `createTeams` (PR #67723)

Johannes Doerfert via flang-commits flang-commits at lists.llvm.org
Tue Oct 24 12:49:29 PDT 2023


================
@@ -5771,84 +5779,63 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
   BasicBlock *AllocaBB =
       splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
 
+  // Generate the body of teams.
+  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+  BodyGenCB(AllocaIP, CodeGenIP);
+
   OutlineInfo OI;
   OI.EntryBB = AllocaBB;
   OI.ExitBB = ExitBB;
   OI.OuterAllocaBB = &OuterAllocaBB;
-  OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
-    // The input IR here looks like the following-
-    // ```
-    // func @current_fn() {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
-    //
-    // This is changed to the following-
-    //
-    // ```
-    // func @current_fn() {
-    //   runtime_call(..., wrapper_fn, ...)
-    // }
-    // func @wrapper_fn(..., %args) {
-    //   outlined_fn(%args)
-    // }
-    // func @outlined_fn(%args) { ... }
-    // ```
 
+  // Insert fake values for global tid and bound tid.
+  std::stack<Instruction *> ToBeDeleted;
+  InsertPointTy OuterAllocaIP(&OuterAllocaBB, OuterAllocaBB.begin());
+  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "gid", true));
+  OI.ExcludeArgsFromAggregate.push_back(createFakeIntVal(
+      Builder, OuterAllocaIP, ToBeDeleted, AllocaIP, "tid", true));
+
+  OI.PostOutlineCB = [this, Ident, ToBeDeleted](Function &OutlinedFn) mutable {
     // The stale call instruction will be replaced with a new call instruction
-    // for runtime call with a wrapper function.
+    // for runtime call with the outlined function.
 
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    ToBeDeleted.push(StaleCI);
+
+    assert((OutlinedFn.arg_size() == 2 || OutlinedFn.arg_size() == 3) &&
+           "Outlined function must have two or three arguments only");
+
+    bool HasShared = OutlinedFn.arg_size() == 3;
 
-    // Create the wrapper function.
-    SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
-    for (auto &Arg : OutlinedFn.args())
-      WrapperArgTys.push_back(Arg.getType());
-    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
-        (Twine(OutlinedFn.getName()) + ".teams").str(),
-        FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
-    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
-    WrapperFunc->getArg(0)->setName("global_tid");
-    WrapperFunc->getArg(1)->setName("bound_tid");
-    if (WrapperFunc->arg_size() > 2)
-      WrapperFunc->getArg(2)->setName("data");
-
-    // Emit the body of the wrapper function - just a call to outlined function
-    // and return statement.
-    BasicBlock *WrapperEntryBB =
-        BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
-    Builder.SetInsertPoint(WrapperEntryBB);
-    SmallVector<Value *> Args;
-    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
-      Args.push_back(WrapperFunc->getArg(ArgIndex));
-    Builder.CreateCall(&OutlinedFn, Args);
-    Builder.CreateRetVoid();
-
-    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+    OutlinedFn.getArg(0)->setName("global.tid.ptr");
+    OutlinedFn.getArg(1)->setName("bound.tid.ptr");
+    if (HasShared)
+      OutlinedFn.getArg(2)->setName("data");
 
     // Call to the runtime function for teams in the current function.
     assert(StaleCI && "Error while outlining - no CallInst user found for the "
                       "outlined function.");
     Builder.SetInsertPoint(StaleCI);
-    Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
-    for (Use &Arg : StaleCI->args())
-      Args.push_back(Arg);
+    SmallVector<Value *> Args = {
+        Ident, Builder.getInt32(StaleCI->arg_size() - 2), &OutlinedFn};
+    if (HasShared)
+      Args.push_back(StaleCI->getArgOperand(2));
     Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                            omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                        Args);
-    StaleCI->eraseFromParent();
+
+    while (!ToBeDeleted.empty()) {
+      ToBeDeleted.top()->eraseFromParent();
+      ToBeDeleted.pop();
----------------
jdoerfert wrote:

same

https://github.com/llvm/llvm-project/pull/67723


More information about the flang-commits mailing list