[llvm] [OMPIRBuilder] Added `createTeams` (PR #66807)

Kiran Chandramohan via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 19 14:58:40 PDT 2023


================
@@ -5735,6 +5735,140 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
   return Builder.saveIP();
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
+                             BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
+
+  // Outer allocation basicblock is the entry block of the current function.
+  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
+  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
+    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
+    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
+  }
+
+  // The current basic block is split into four basic blocks. After outlining,
+  // they will be mapped as follows:
+  // ```
+  // def current_fn() {
+  //   current_basic_block:
+  //     br label %teams.exit
+  //   teams.exit:
+  //     ; instructions after teams
+  // }
+  //
+  // def outlined_fn() {
+  //   teams.alloca:
+  //     br label %teams.body
+  //   teams.body:
+  //     ; instructions within teams body
+  // }
+  // ```
+  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
+  BasicBlock *AllocaBB =
+      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+
+  OutlineInfo OI;
+  OI.EntryBB = AllocaBB;
+  OI.ExitBB = ExitBB;
+  OI.OuterAllocaBB = &OuterAllocaBB;
+  OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
+    // The input IR here looks like the following-
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+    //
+    // This is changed to the following-
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+
+    // The stale call instruction will be replaced with a new call instruction
+    // for runtime call with a wrapper function.
+
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    assert(StaleCI && "Error while outlining - no CallInst user found for the "
+                      "outlined function.");
+    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
+    // Create the wrapper function.
+    Builder.SetInsertPoint(StaleCI);
+    SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
+    for (auto &Arg : OutlinedFn.args()) {
+      WrapperArgTys.push_back(Arg.getType());
+    }
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        (Twine(OutlinedFn.getName()) + ".teams").str(),
+        FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
+    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
+    WrapperFunc->getArg(0)->setName("global_tid");
+    WrapperFunc->getArg(1)->setName("bound_tid");
+    if (WrapperFunc->arg_size() > 2) {
+      WrapperFunc->getArg(2)->setName("data");
+    }
+
+    // Emit the body of the wrapper function - just a call to outlined function
+    // and return statement.
+    BasicBlock *WrapperEntryBB =
+        BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    SmallVector<Value *> Args;
+    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) {
+      Args.push_back(WrapperFunc->getArg(ArgIndex));
+    }
+    Builder.CreateCall(&OutlinedFn, Args);
+    Builder.CreateRetVoid();
+
+    // Call to the runtime function for teams in the current function.
+    Builder.SetInsertPoint(StaleCI);
+    Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
+    for (Use &Arg : StaleCI->args()) {
+      Args.push_back(Arg);
+    }
+    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
+                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
+                       Args);
+    StaleCI->eraseFromParent();
+
+    // Inlining the outlined teams function in the wrapper. This wrapper is the
+    // argument for the runtime call.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "More than one use for the outlined function found. Expected only "
+           "one use.");
+    OutlinedFn.addFnAttr(Attribute::AlwaysInline);
----------------
kiranchandramohan wrote:

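This was already done on line 5811 (`OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);`), so adding the `AlwaysInline` attribute again here is redundant.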

https://github.com/llvm/llvm-project/pull/66807


More information about the llvm-commits mailing list