[llvm] 8d17875 - [OMPIRBuilder] Added `createTeams` (#66807)

via llvm-commits llvm-commits at lists.llvm.org
Sun Sep 24 14:23:48 PDT 2023


Author: Shraiysh
Date: 2023-09-24T16:23:43-05:00
New Revision: 8d17875acba650ef7b3fa516fd09633a50dd0427

URL: https://github.com/llvm/llvm-project/commit/8d17875acba650ef7b3fa516fd09633a50dd0427
DIFF: https://github.com/llvm/llvm-project/commit/8d17875acba650ef7b3fa516fd09633a50dd0427.diff

LOG: [OMPIRBuilder] Added `createTeams` (#66807)

This patch adds basic support for `omp teams` to the OpenMPIRBuilder.

The outlined function after code extraction is called from a wrapper
function with appropriate arguments. This wrapper function is passed to
the runtime calls.

This approach is different from the Clang approach - clang directly
emits the runtime call to the outlined function. The outlining utility
(OutlineInfo) simply outlines the code and generates a function call to
the outlined function. After the function has been generated by the
outlining utility, there is no easy way to alter the function arguments
without meddling with the outlining itself. Hence the wrapper function
approach is taken.

Added: 
    

Modified: 
    llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
    llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
    llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp

Removed: 
    


################################################################################
diff  --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 523a0718e1ffb53..1699ed3aeab7661 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1889,6 +1889,13 @@ class OpenMPIRBuilder {
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB);
 
+  /// Generator for `#omp teams`
+  /// \param Loc The location where the teams construct was encountered.
+  /// \param BodyGenCB Callback that will generate the region code.
+  /// \returns The insertion point after the teams region.
+  InsertPointTy createTeams(const LocationDescription &Loc,
+                            BodyGenCallbackTy BodyGenCB);
+
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
   /// copies.

diff  --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1ace7d5b97ffc96..8a5f4cb1b85ff9c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5735,6 +5735,129 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
   return Builder.saveIP();
 }
 
+/// Emit an `omp teams` region at \p Loc and return the IP following it. The
+/// body produced by \p BodyGenCB is registered for outlining; afterwards the
+/// call is rerouted through a wrapper passed to `__kmpc_fork_teams`.
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
+                             BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
+
+  // Outer allocation basicblock is the entry block of the current function.
+  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
+  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
+    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
+    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
+  }
+
+  // The current basic block is split into four basic blocks. After outlining,
+  // they will be mapped as follows:
+  // ```
+  // def current_fn() {
+  //   current_basic_block:
+  //     br label %teams.exit
+  //   teams.exit:
+  //     ; instructions after teams
+  // }
+  //
+  // def outlined_fn() {
+  //   teams.alloca:
+  //     br label %teams.body
+  //   teams.body:
+  //     ; instructions within teams body
+  // }
+  // ```
+  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
+  BasicBlock *AllocaBB =
+      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+
+  OutlineInfo OI;
+  OI.EntryBB = AllocaBB;
+  OI.ExitBB = ExitBB;
+  OI.OuterAllocaBB = &OuterAllocaBB;
+  OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
+    // The input IR here looks like the following-
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+    //
+    // This is changed to the following-
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+
+    // The stale call instruction (the only user of the outlined function) will
+    // be replaced with a call to the runtime that takes a wrapper function.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+
+    // Create the wrapper function; cast<> asserts the callee is a Function.
+    SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
+    for (auto &Arg : OutlinedFn.args())
+      WrapperArgTys.push_back(Arg.getType());
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        (Twine(OutlinedFn.getName()) + ".teams").str(),
+        FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
+    Function *WrapperFunc = cast<Function>(WrapperFuncVal.getCallee());
+    WrapperFunc->getArg(0)->setName("global_tid");
+    WrapperFunc->getArg(1)->setName("bound_tid");
+    if (WrapperFunc->arg_size() > 2)
+      WrapperFunc->getArg(2)->setName("data");
+
+    // Emit the body of the wrapper function - just a call to outlined function
+    // and return statement.
+    BasicBlock *WrapperEntryBB =
+        BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    SmallVector<Value *> Args;
+    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
+      Args.push_back(WrapperFunc->getArg(ArgIndex));
+    Builder.CreateCall(&OutlinedFn, Args);
+    Builder.CreateRetVoid();
+
+    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
+    // Call to the runtime function for teams in the current function.
+    Builder.SetInsertPoint(StaleCI);
+    Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
+    for (Use &Arg : StaleCI->args())
+      Args.push_back(Arg);
+    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
+                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
+                       Args);
+    StaleCI->eraseFromParent();
+  };
+
+  addOutlineInfo(std::move(OI));
+
+  // Generate the body of teams.
+  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+  BodyGenCB(AllocaIP, CodeGenIP);
+
+  Builder.SetInsertPoint(ExitBB, ExitBB->begin());
+
+  return Builder.saveIP();
+}
+
 GlobalVariable *
 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                        std::string VarName) {

diff  --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 2026824416f3e3c..fd524f6067ee0ea 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4001,6 +4001,87 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareCapture) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+// Check that createTeams forks teams through a wrapper of the outlined body.
+TEST_F(OpenMPIRBuilderTest, CreateTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 = Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "load");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    Builder.CreateStore(Val128, Local128);
+    Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                      "bodygen.load32");
+
+    LoadInst *PrivLoad128 = Builder.CreateLoad(
+        Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+    Value *Cmp = Builder.CreateICmpNE(
+        Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+    Instruction *ThenTerm, *ElseTerm;
+    SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+  };
+
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TeamsForkCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)
+          ->user_back());
+  ASSERT_NE(TeamsForkCall, nullptr);
+
+  // Verify the Ident argument
+  auto *Ident = dyn_cast<GlobalVariable>(TeamsForkCall->getArgOperand(0));
+  ASSERT_NE(Ident, nullptr);
+  EXPECT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob =
+      dyn_cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+  ASSERT_NE(SrcStrGlob, nullptr);
+  ConstantDataArray *SrcSrc =
+      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  ASSERT_NE(SrcSrc, nullptr);
+
+  // Verify the wrapper function signature.
+  Function *WrapperFn =
+      dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
+  ASSERT_NE(WrapperFn, nullptr);
+  EXPECT_FALSE(WrapperFn->isDeclaration());
+  ASSERT_GE(WrapperFn->arg_size(), 3u);
+  EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
+  EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
+  EXPECT_EQ(WrapperFn->getArg(2)->getType(), Builder.getPtrTy()); // data
+
+  // Check for TruncInst and ICmpInst in the outlined function.
+  inst_range Instructions = instructions(WrapperFn);
+  auto OutlinedFnInst = find_if(
+      Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); });
+  ASSERT_NE(OutlinedFnInst, Instructions.end());
+  CallInst *OutlinedFnCI = cast<CallInst>(&*OutlinedFnInst);
+  Function *OutlinedFn = OutlinedFnCI->getCalledFunction();
+  ASSERT_NE(OutlinedFn, nullptr);
+
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>


        


More information about the llvm-commits mailing list