[llvm] [OMPIRBuilder] Added `createTeams` (PR #66807)

via llvm-commits llvm-commits at lists.llvm.org
Tue Sep 19 22:32:02 PDT 2023


https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/66807

>From da2ebe2f8d164caf506c1fecadc849db194fd474 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Tue, 19 Sep 2023 14:57:26 -0500
Subject: [PATCH 1/2] [OMPIRBuilder] Added `createTeams`

This patch adds basic support for `omp teams` to the OpenMPIRBuilder.

The outlined function after code extraction is called from a wrapper function
with appropriate arguments. This wrapper function is passed to the runtime
calls.

This approach differs from Clang's: Clang emits the runtime call to the
outlined function directly, whereas the outlining utility (OutlineInfo) only
outlines the code and emits a plain call to the outlined function.
After the function has been generated by the outlining utility, there is no easy
way to alter the function arguments without meddling with the outlining itself.
Hence the wrapper function approach is taken.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |   7 +
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 134 ++++++++++++++++++
 .../Frontend/OpenMPIRBuilderTest.cpp          |  81 +++++++++++
 3 files changed, 222 insertions(+)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 523a0718e1ffb53..1699ed3aeab7661 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1889,6 +1889,13 @@ class OpenMPIRBuilder {
                               BodyGenCallbackTy BodyGenCB,
                               FinalizeCallbackTy FiniCB);
 
+  /// Generator for `#pragma omp teams`.
+  ///
+  /// \param Loc The location where the teams construct was encountered.
+  /// \param BodyGenCB Callback that generates the body of the teams region.
+  InsertPointTy createTeams(const LocationDescription &Loc,
+                            BodyGenCallbackTy BodyGenCB);
+
   /// Generate conditional branch and relevant BasicBlocks through which private
   /// threads copy the 'copyin' variables from Master copy to threadprivate
   /// copies.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 1ace7d5b97ffc96..4ec5bcaf11a7129 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5735,6 +5735,140 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createAtomicCompare(
   return Builder.saveIP();
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
+                             BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc))
+    return InsertPointTy();
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+  Function *CurrentFunction = Builder.GetInsertBlock()->getParent();
+
+  // Outer allocation basic block is the entry block of the current function.
+  BasicBlock &OuterAllocaBB = CurrentFunction->getEntryBlock();
+  if (&OuterAllocaBB == Builder.GetInsertBlock()) {
+    BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.entry");
+    Builder.SetInsertPoint(BodyBB, BodyBB->begin());
+  }
+
+  // The current basic block is split into four basic blocks. After outlining,
+  // they will be mapped as follows:
+  // ```
+  // def current_fn() {
+  //   current_basic_block:
+  //     br label %teams.exit
+  //   teams.exit:
+  //     ; instructions after teams
+  // }
+  //
+  // def outlined_fn() {
+  //   teams.alloca:
+  //     br label %teams.body
+  //   teams.body:
+  //     ; instructions within teams body
+  // }
+  // ```
+  BasicBlock *ExitBB = splitBB(Builder, /*CreateBranch=*/true, "teams.exit");
+  BasicBlock *BodyBB = splitBB(Builder, /*CreateBranch=*/true, "teams.body");
+  BasicBlock *AllocaBB =
+      splitBB(Builder, /*CreateBranch=*/true, "teams.alloca");
+
+  OutlineInfo OI;
+  OI.EntryBB = AllocaBB;
+  OI.ExitBB = ExitBB;
+  OI.OuterAllocaBB = &OuterAllocaBB;
+  OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
+    // The input IR here looks like the following:
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+    //
+    // This is changed to the following:
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) { ... }
+    // ```
+
+    // The stale call instruction will be replaced with a new call instruction
+    // for runtime call with a wrapper function.
+
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    assert(StaleCI && "Error while outlining - no CallInst user found for the "
+                      "outlined function.");
+    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
+    // Create the wrapper function.
+    Builder.SetInsertPoint(StaleCI);
+    SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
+    for (auto &Arg : OutlinedFn.args()) {
+      WrapperArgTys.push_back(Arg.getType());
+    }
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        (Twine(OutlinedFn.getName()) + ".teams").str(),
+        FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
+    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
+    WrapperFunc->getArg(0)->setName("global_tid");
+    WrapperFunc->getArg(1)->setName("bound_tid");
+    if (WrapperFunc->arg_size() > 2) {
+      WrapperFunc->getArg(2)->setName("data");
+    }
+
+    // Emit the body of the wrapper function - just a call to outlined function
+    // and return statement.
+    BasicBlock *WrapperEntryBB =
+        BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    SmallVector<Value *> Args;
+    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) {
+      Args.push_back(WrapperFunc->getArg(ArgIndex));
+    }
+    Builder.CreateCall(&OutlinedFn, Args);
+    Builder.CreateRetVoid();
+
+    // Call to the runtime function for teams in the current function.
+    Builder.SetInsertPoint(StaleCI);
+    Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
+    for (Use &Arg : StaleCI->args()) {
+      Args.push_back(Arg);
+    }
+    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
+                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
+                       Args);
+    StaleCI->eraseFromParent();
+
+    // Inlining the outlined teams function in the wrapper. This wrapper is the
+    // argument for the runtime call.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "More than one use for the outlined function found. Expected only "
+           "one use.");
+    OutlinedFn.addFnAttr(Attribute::AlwaysInline);
+  };
+
+  addOutlineInfo(std::move(OI));
+
+  // Generate the body of teams.
+  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+  BodyGenCB(AllocaIP, CodeGenIP);
+
+  Builder.SetInsertPoint(ExitBB, ExitBB->begin());
+
+  return Builder.saveIP();
+}
+
 GlobalVariable *
 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                        std::string VarName) {
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 2026824416f3e3c..fd524f6067ee0ea 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -4001,6 +4001,87 @@ TEST_F(OpenMPIRBuilderTest, OMPAtomicCompareCapture) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
+TEST_F(OpenMPIRBuilderTest, CreateTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 = Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "load");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    Builder.CreateStore(Val128, Local128);
+    Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                      "bodygen.load32");
+
+    LoadInst *PrivLoad128 = Builder.CreateLoad(
+        Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+    Value *Cmp = Builder.CreateICmpNE(
+        Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+    Instruction *ThenTerm, *ElseTerm;
+    SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+  };
+
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TeamsForkCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)
+          ->user_back());
+  ASSERT_NE(TeamsForkCall, nullptr);
+
+  // Verify the Ident argument
+  auto *Ident = dyn_cast<GlobalVariable>(TeamsForkCall->getArgOperand(0));
+  ASSERT_NE(Ident, nullptr);
+  EXPECT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob =
+      dyn_cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+  ASSERT_NE(SrcStrGlob, nullptr);
+  ConstantDataArray *SrcSrc =
+      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  ASSERT_NE(SrcSrc, nullptr);
+
+  // Verify the wrapper function signature.
+  Function *WrapperFn =
+      dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
+  ASSERT_NE(WrapperFn, nullptr);
+  EXPECT_FALSE(WrapperFn->isDeclaration());
+  ASSERT_GE(WrapperFn->arg_size(), 3u);
+  EXPECT_EQ(WrapperFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
+  EXPECT_EQ(WrapperFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
+  EXPECT_EQ(WrapperFn->getArg(2)->getType(), Builder.getPtrTy()); // data
+
+  // The wrapper must call the (defined) outlined function with the teams body.
+  inst_range Instructions = instructions(WrapperFn);
+  auto OutlinedFnInst = find_if(
+      Instructions, [](Instruction &Inst) { return isa<CallInst>(&Inst); });
+  ASSERT_NE(OutlinedFnInst, Instructions.end());
+  CallInst *OutlinedFnCI = dyn_cast<CallInst>(&*OutlinedFnInst);
+  ASSERT_NE(OutlinedFnCI, nullptr);
+  Function *OutlinedFn = OutlinedFnCI->getCalledFunction();
+  ASSERT_TRUE(OutlinedFn && !OutlinedFn->isDeclaration());
+
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+}
+
 /// Returns the single instruction of InstTy type in BB that uses the value V.
 /// If there is more than one such instruction, returns null.
 template <typename InstTy>

>From b24dcdc5d5bd82b4768c3cf37ed9ae8f4699ad97 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Wed, 20 Sep 2023 00:31:15 -0500
Subject: [PATCH 2/2] Address comments

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp | 27 +++++++----------------
 1 file changed, 8 insertions(+), 19 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 4ec5bcaf11a7129..f88995c942bce99 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -5803,28 +5803,24 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
     // The stale call instruction will be replaced with a new call instruction
     // for runtime call with a wrapper function.
 
+    // Fetch the stale call now, while it is still the only user of OutlinedFn;
+    // emitting the wrapper body below adds a second use.
     assert(OutlinedFn.getNumUses() == 1 &&
            "there must be a single user for the outlined function");
     CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
-    assert(StaleCI && "Error while outlining - no CallInst user found for the "
-                      "outlined function.");
-    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
 
     // Create the wrapper function.
-    Builder.SetInsertPoint(StaleCI);
     SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
-    for (auto &Arg : OutlinedFn.args()) {
+    for (auto &Arg : OutlinedFn.args())
       WrapperArgTys.push_back(Arg.getType());
-    }
     FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
         (Twine(OutlinedFn.getName()) + ".teams").str(),
         FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
     Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
     WrapperFunc->getArg(0)->setName("global_tid");
     WrapperFunc->getArg(1)->setName("bound_tid");
-    if (WrapperFunc->arg_size() > 2) {
+    if (WrapperFunc->arg_size() > 2)
       WrapperFunc->getArg(2)->setName("data");
-    }
 
     // Emit the body of the wrapper function - just a call to outlined function
     // and return statement.
@@ -5832,29 +5828,22 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
         BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
     Builder.SetInsertPoint(WrapperEntryBB);
     SmallVector<Value *> Args;
-    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) {
+    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++)
       Args.push_back(WrapperFunc->getArg(ArgIndex));
-    }
     Builder.CreateCall(&OutlinedFn, Args);
     Builder.CreateRetVoid();
 
+    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
     // Call to the runtime function for teams in the current function.
     Builder.SetInsertPoint(StaleCI);
     Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
-    for (Use &Arg : StaleCI->args()) {
+    for (Use &Arg : StaleCI->args())
       Args.push_back(Arg);
-    }
     Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
                            omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
                        Args);
     StaleCI->eraseFromParent();
-
-    // Inlining the outlined teams function in the wrapper. This wrapper is the
-    // argument for the runtime call.
-    assert(OutlinedFn.getNumUses() == 1 &&
-           "More than one use for the outlined function found. Expected only "
-           "one use.");
-    OutlinedFn.addFnAttr(Attribute::AlwaysInline);
   };
 
   addOutlineInfo(std::move(OI));



More information about the llvm-commits mailing list