[llvm] [OpenMPIRBuilder] Added `createTeams` (PR #65767)

via llvm-commits llvm-commits at lists.llvm.org
Fri Sep 8 08:56:39 PDT 2023


https://github.com/shraiysh updated https://github.com/llvm/llvm-project/pull/65767:

>From 2a01a4999621f79e08ad6fef6544413c90c3012f Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Thu, 7 Sep 2023 15:45:52 -0500
Subject: [PATCH 1/2] [OpenMPIRBuilder] Added `createTeams`

This patch adds a generator for the teams construct. The generated IR looks like the following:

```
current_fn() {
  ...
  call @__kmpc_fork_teams(ptr @ident, i32 num_args, ptr @outlined_omp_teams, ...args)
  ...
}
outlined_omp_teams(ptr %global_tid, ptr %bound_tid, ...args) {
  ; teams body
}
```

It does this by first generating the teams body in the current function. The body is
then outlined into a temporary function. Next, the @outlined_omp_teams wrapper function
is created with the signature expected by the runtime, and the temporary outlined
function is inlined into it. Finally, the call to the runtime function
@__kmpc_fork_teams is emitted in the current function.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  11 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 188 ++++++++++++++++--
 .../Frontend/OpenMPIRBuilderTest.cpp          |  73 +++++++
 3 files changed, 248 insertions(+), 24 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index ea1035f1907e492..d26ac60939031ec 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1159,8 +1159,8 @@ class OpenMPIRBuilder {
                                 InsertPointTy AllocaIP,
                                 BodyGenCallbackTy BodyGenCB);
 
-
-  using FileIdentifierInfoCallbackTy = std::function<std::tuple<std::string, uint64_t>()>;
+  using FileIdentifierInfoCallbackTy =
+      std::function<std::tuple<std::string, uint64_t>()>;
 
   /// Creates a unique info for a target entry when provided a filename and
   /// line number from.
@@ -2005,6 +2005,13 @@ class OpenMPIRBuilder {
   /// \param Loc The insert and source location description.
   void createTargetDeinit(const LocationDescription &Loc);
 
+  /// Generator for `#omp teams`
+  ///
+  /// \param Loc The location where the task construct was encountered.
+  /// \param BodyGenCB Callback that will generate the region code.
+  InsertPointTy createTeams(const LocationDescription &Loc,
+                            BodyGenCallbackTy BodyGenCB);
+
   ///}
 
 private:
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2cfb36d11dcf898..0f2203f0c1ac84c 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -36,10 +36,12 @@
 #include "llvm/IR/LLVMContext.h"
 #include "llvm/IR/MDBuilder.h"
 #include "llvm/IR/Metadata.h"
+#include "llvm/IR/Module.h"
 #include "llvm/IR/PassManager.h"
 #include "llvm/IR/Value.h"
 #include "llvm/MC/TargetRegistry.h"
 #include "llvm/Support/CommandLine.h"
+#include "llvm/Support/Debug.h"
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Target/TargetMachine.h"
@@ -390,9 +392,9 @@ void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
       if (Param) {
         if (auto AK = TargetLibraryInfo::getExtAttrForI32Param(T, HasSignExt))
           FnAS = FnAS.addAttribute(Ctx, AK);
-      } else
-        if (auto AK = TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
-          FnAS = FnAS.addAttribute(Ctx, AK);
+      } else if (auto AK =
+                     TargetLibraryInfo::getExtAttrForI32Return(T, HasSignExt))
+        FnAS = FnAS.addAttribute(Ctx, AK);
     } else {
       FnAS = FnAS.addAttributes(Ctx, AS);
     }
@@ -406,7 +408,7 @@ void OpenMPIRBuilder::addAttributes(omp::RuntimeFunction FnID, Function &Fn) {
 #define OMP_RTL_ATTRS(Enum, FnAttrSet, RetAttrSet, ArgAttrSets)                \
   case Enum:                                                                   \
     FnAttrs = FnAttrs.addAttributes(Ctx, FnAttrSet);                           \
-    addAttrSet(RetAttrs, RetAttrSet, /*Param*/false);                          \
+    addAttrSet(RetAttrs, RetAttrSet, /*Param*/ false);                         \
     for (size_t ArgNo = 0; ArgNo < ArgAttrSets.size(); ++ArgNo)                \
       addAttrSet(ArgAttrs[ArgNo], ArgAttrSets[ArgNo]);                         \
     Fn.setAttributes(AttributeList::get(Ctx, FnAttrs, RetAttrs, ArgAttrs));    \
@@ -4927,8 +4929,8 @@ void OpenMPIRBuilder::emitOffloadingArrays(
             static_cast<std::underlying_type_t<OpenMPOffloadMappingFlags>>(
                 CombinedInfo.Types[I] &
                 OpenMPOffloadMappingFlags::OMP_MAP_NON_CONTIG))
-          ConstSizes[I] = ConstantInt::get(Int64Ty,
-                                           CombinedInfo.NonContigInfo.Dims[I]);
+          ConstSizes[I] =
+              ConstantInt::get(Int64Ty, CombinedInfo.NonContigInfo.Dims[I]);
         else
           ConstSizes[I] = CI;
         continue;
@@ -4991,8 +4993,8 @@ void OpenMPIRBuilder::emitOffloadingArrays(
         createOffloadMapnames(CombinedInfo.Names, MapnamesName);
     Info.RTArgs.MapNamesArray = MapNamesArrayGbl;
   } else {
-    Info.RTArgs.MapNamesArray = Constant::getNullValue(
-        PointerType::getUnqual(Builder.getContext()));
+    Info.RTArgs.MapNamesArray =
+        Constant::getNullValue(PointerType::getUnqual(Builder.getContext()));
   }
 
   // If there's a present map type modifier, it must not be applied to the end
@@ -5017,10 +5019,10 @@ void OpenMPIRBuilder::emitOffloadingArrays(
   for (unsigned I = 0; I < Info.NumberOfPtrs; ++I) {
     Value *BPVal = CombinedInfo.BasePointers[I];
     Value *BP = Builder.CreateConstInBoundsGEP2_32(
-        ArrayType::get(PtrTy, Info.NumberOfPtrs),
-        Info.RTArgs.BasePointersArray, 0, I);
-    Builder.CreateAlignedStore(
-        BPVal, BP, M.getDataLayout().getPrefTypeAlign(PtrTy));
+        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.BasePointersArray,
+        0, I);
+    Builder.CreateAlignedStore(BPVal, BP,
+                               M.getDataLayout().getPrefTypeAlign(PtrTy));
 
     if (Info.requiresDevicePointerInfo()) {
       if (CombinedInfo.DevicePointers[I] == DeviceInfoTy::Pointer) {
@@ -5039,21 +5041,21 @@ void OpenMPIRBuilder::emitOffloadingArrays(
 
     Value *PVal = CombinedInfo.Pointers[I];
     Value *P = Builder.CreateConstInBoundsGEP2_32(
-        ArrayType::get(PtrTy, Info.NumberOfPtrs),
-        Info.RTArgs.PointersArray, 0, I);
+        ArrayType::get(PtrTy, Info.NumberOfPtrs), Info.RTArgs.PointersArray, 0,
+        I);
     // TODO: Check alignment correct.
-    Builder.CreateAlignedStore(
-        PVal, P, M.getDataLayout().getPrefTypeAlign(PtrTy));
+    Builder.CreateAlignedStore(PVal, P,
+                               M.getDataLayout().getPrefTypeAlign(PtrTy));
 
     if (RuntimeSizes.test(I)) {
       Value *S = Builder.CreateConstInBoundsGEP2_32(
           ArrayType::get(Int64Ty, Info.NumberOfPtrs), Info.RTArgs.SizesArray,
           /*Idx0=*/0,
           /*Idx1=*/I);
-      Builder.CreateAlignedStore(
-          Builder.CreateIntCast(CombinedInfo.Sizes[I], Int64Ty,
-                                /*isSigned=*/true),
-          S, M.getDataLayout().getPrefTypeAlign(PtrTy));
+      Builder.CreateAlignedStore(Builder.CreateIntCast(CombinedInfo.Sizes[I],
+                                                       Int64Ty,
+                                                       /*isSigned=*/true),
+                                 S, M.getDataLayout().getPrefTypeAlign(PtrTy));
     }
     // Fill up the mapper array.
     unsigned IndexSize = M.getDataLayout().getIndexSizeInBits(0);
@@ -5655,8 +5657,8 @@ GlobalVariable *
 OpenMPIRBuilder::createOffloadMapnames(SmallVectorImpl<llvm::Constant *> &Names,
                                        std::string VarName) {
   llvm::Constant *MapNamesArrayInit = llvm::ConstantArray::get(
-      llvm::ArrayType::get(
-          llvm::PointerType::getUnqual(M.getContext()), Names.size()),
+      llvm::ArrayType::get(llvm::PointerType::getUnqual(M.getContext()),
+                           Names.size()),
       Names);
   auto *MapNamesArrayGlobal = new llvm::GlobalVariable(
       M, MapNamesArrayInit->getType(),
@@ -6106,6 +6108,148 @@ void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
   }
 }
 
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
+                             BodyGenCallbackTy BodyGenCB) {
+  if (!updateToLocation(Loc)) {
+    return Loc.IP;
+  }
+
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  // Splitting a basic block expects a terminator. Hence, creating an
+  // unreachable instruction, which will be deleted later.
+  UnreachableInst *UI = Builder.CreateUnreachable();
+  BasicBlock *CurrentBasicBlock = Builder.GetInsertBlock();
+
+  // The current basic block is split into four basic blocks. After outlining,
+  // they will be mapped as follows:
+  // ```
+  // def current_fn() {
+  //   current_basic_block:
+  //     br label %teams.exit
+  //   teams.exit:
+  //     ; instructions after teams
+  // }
+  // def outlined_fn() {
+  //   teams.alloca:
+  //     br label %teams.body
+  //   teams.body:
+  //     ; instructions within teams body
+  // }
+  // ```
+  BasicBlock *AllocaBB = CurrentBasicBlock->splitBasicBlock(UI, "teams.alloca");
+  BasicBlock *BodyBB = AllocaBB->splitBasicBlock(UI, "teams.body");
+  BasicBlock *ExitBB = BodyBB->splitBasicBlock(UI, "teams.exit");
+
+  UI->eraseFromParent();
+
+  // Generate the body of teams.
+  InsertPointTy CodeGenIP(BodyBB, BodyBB->begin());
+  InsertPointTy AllocaIP(AllocaBB, AllocaBB->begin());
+  BodyGenCB(AllocaIP, CodeGenIP);
+
+  OutlineInfo OI;
+  OI.EntryBB = AllocaBB;
+  OI.ExitBB = ExitBB;
+  OI.PostOutlineCB = [this, Ident](Function &OutlinedFn) {
+    // The input IR here looks like the following-
+    // ```
+    // func @current_fn() {
+    //   outlined_fn(%args)
+    // }
+    // func @outlined_fn(%args) {
+    //   ; teams body
+    // }
+    // ```
+    //
+    // This is changed to the following-
+    //
+    // ```
+    // func @current_fn() {
+    //   runtime_call(..., wrapper_fn, ...)
+    // }
+    // func @wrapper_fn(..., %args) {
+    //   ; teams body
+    // }
+    // ```
+
+    // The outlined function has different inputs than what is expected from it.
+    // So, a wrapper function with expected signature is created and the
+    // required arguments are passed to the outlined function. The stale call
+    // instruction in current function will be replaced with a new call
+    // instruction for runtime call with the wrapper function. The outlined
+    // function is then inlined in the wrapper function and the call from the
+    // current function is removed.
+
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "there must be a single user for the outlined function");
+    CallInst *StaleCI = cast<CallInst>(OutlinedFn.user_back());
+    assert(StaleCI && "Error while outlining - no CallInst user found for the "
+                      "outlined function.");
+    OutlinedFn.addFnAttr(Attribute::AttrKind::AlwaysInline);
+
+    // Create the wrapper function.
+    Builder.SetInsertPoint(StaleCI);
+    SmallVector<Type *> WrapperArgTys{Builder.getPtrTy(), Builder.getPtrTy()};
+    for (auto &Arg : OutlinedFn.args()) {
+      WrapperArgTys.push_back(Arg.getType());
+    }
+    FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
+        "outlined_omp_teams",
+        FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
+    Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
+    WrapperFunc->getArg(0)->setName("global_tid");
+    WrapperFunc->getArg(1)->setName("bound_tid");
+    WrapperFunc->getArg(2)->setName("data");
+
+    // Emit the body of the wrapper function - just a call to outlined function
+    // and return statement.
+    BasicBlock *WrapperEntryBB =
+        BasicBlock::Create(M.getContext(), "entrybb", WrapperFunc);
+    Builder.SetInsertPoint(WrapperEntryBB);
+    SmallVector<Value *> Args;
+    for (size_t ArgIndex = 2; ArgIndex < WrapperFunc->arg_size(); ArgIndex++) {
+      Args.push_back(WrapperFunc->getArg(ArgIndex));
+    }
+    CallInst *OutlinedFnCall = Builder.CreateCall(&OutlinedFn, Args);
+    Builder.CreateRetVoid();
+
+    // Call to the runtime function for teams in the current function.
+    Builder.SetInsertPoint(StaleCI);
+    Args = {Ident, Builder.getInt32(StaleCI->arg_size()), WrapperFunc};
+    for (Use &Arg : StaleCI->args()) {
+      Args.push_back(Arg);
+    }
+    Builder.CreateCall(getOrCreateRuntimeFunctionPtr(
+                           omp::RuntimeFunction::OMPRTL___kmpc_fork_teams),
+                       Args);
+    StaleCI->eraseFromParent();
+
+    // Inlining the outlined teams function in the wrapper. This wrapper is the
+    // argument for the runtime call.
+    assert(OutlinedFn.getNumUses() == 1 &&
+           "More than one use for the outlined function found. Expected only "
+           "one use.");
+    InlineFunctionInfo IFI;
+    InlineResult IR = InlineFunction(*OutlinedFnCall, IFI);
+    LLVM_DEBUG(if (!IR.isSuccess()) {
+      dbgs() << "Attempt to merge the outlined function in the wrapper failed: "
+             << IR.getFailureReason() << "\n";
+    });
+    assert(IR.isSuccess() && "Inlining outlined omp teams failed");
+    OutlinedFn.eraseFromParent();
+  };
+
+  addOutlineInfo(std::move(OI));
+
+  Builder.SetInsertPoint(ExitBB);
+
+  return Builder.saveIP();
+}
+
 bool OffloadEntriesInfoManager::empty() const {
   return OffloadEntriesTargetRegion.empty() &&
          OffloadEntriesDeviceGlobalVar.empty();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 5870457956b5433..b189559b3430461 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6033,4 +6033,77 @@ TEST_F(OpenMPIRBuilderTest, createGPUOffloadEntry) {
   EXPECT_TRUE(Fn->hasFnAttribute(Attribute::MustProgress));
 }
 
+TEST_F(OpenMPIRBuilderTest, createTeams) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  AllocaInst *ValPtr32 = Builder.CreateAlloca(Builder.getInt32Ty());
+  AllocaInst *ValPtr128 = Builder.CreateAlloca(Builder.getInt128Ty());
+  Value *Val128 =
+      Builder.CreateLoad(Builder.getInt128Ty(), ValPtr128, "bodygen.load");
+
+  auto BodyGenCB = [&](InsertPointTy AllocaIP, InsertPointTy CodeGenIP) {
+    Builder.restoreIP(AllocaIP);
+    AllocaInst *Local128 = Builder.CreateAlloca(Builder.getInt128Ty(), nullptr,
+                                                "bodygen.alloca128");
+
+    Builder.restoreIP(CodeGenIP);
+    // Loading and storing captured pointer and values
+    Builder.CreateStore(Val128, Local128);
+    Value *Val32 = Builder.CreateLoad(ValPtr32->getAllocatedType(), ValPtr32,
+                                      "bodygen.load32");
+
+    LoadInst *PrivLoad128 = Builder.CreateLoad(
+        Local128->getAllocatedType(), Local128, "bodygen.local.load128");
+    Value *Cmp = Builder.CreateICmpNE(
+        Val32, Builder.CreateTrunc(PrivLoad128, Val32->getType()));
+    Instruction *ThenTerm, *ElseTerm;
+    SplitBlockAndInsertIfThenElse(Cmp, CodeGenIP.getBlock()->getTerminator(),
+                                  &ThenTerm, &ElseTerm);
+  };
+
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  Builder.restoreIP(OMPBuilder.createTeams(Builder, BodyGenCB));
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *TeamsForkCall = dyn_cast<CallInst>(
+      OMPBuilder.getOrCreateRuntimeFunctionPtr(OMPRTL___kmpc_fork_teams)
+          ->user_back());
+
+  // Verify the Ident argument
+  GlobalVariable *Ident = cast<GlobalVariable>(TeamsForkCall->getArgOperand(0));
+  ASSERT_NE(Ident, nullptr);
+  EXPECT_TRUE(Ident->hasInitializer());
+  Constant *Initializer = Ident->getInitializer();
+  GlobalVariable *SrcStrGlob =
+      cast<GlobalVariable>(Initializer->getOperand(4)->stripPointerCasts());
+  ASSERT_NE(SrcStrGlob, nullptr);
+  ConstantDataArray *SrcSrc =
+      dyn_cast<ConstantDataArray>(SrcStrGlob->getInitializer());
+  ASSERT_NE(SrcSrc, nullptr);
+
+  // Verify the outlined function signature.
+  Function *OutlinedFn =
+      dyn_cast<Function>(TeamsForkCall->getArgOperand(2)->stripPointerCasts());
+  ASSERT_NE(OutlinedFn, nullptr);
+  EXPECT_FALSE(OutlinedFn->isDeclaration());
+  EXPECT_TRUE(OutlinedFn->arg_size() >= 3);
+  EXPECT_EQ(OutlinedFn->getArg(0)->getType(), Builder.getPtrTy()); // global_tid
+  EXPECT_EQ(OutlinedFn->getArg(1)->getType(), Builder.getPtrTy()); // bound_tid
+  EXPECT_EQ(OutlinedFn->getArg(2)->getType(),
+            Builder.getPtrTy()); // captured args
+
+  // Check for TruncInst and ICmpInst in the outlined function.
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<TruncInst>(&inst); }));
+  EXPECT_TRUE(any_of(instructions(OutlinedFn),
+                     [](Instruction &inst) { return isa<ICmpInst>(&inst); }));
+}
+
 } // namespace

>From df97eaa4a4308165226651e02b8f560a3b0dae55 Mon Sep 17 00:00:00 2001
From: Shraiysh Vaishay <shraiysh.vaishay at amd.com>
Date: Fri, 8 Sep 2023 10:56:17 -0500
Subject: [PATCH 2/2] Addressed comments

---
 llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h | 2 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp        | 5 ++---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index d26ac60939031ec..037e0a5662be5bb 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2007,7 +2007,7 @@ class OpenMPIRBuilder {
 
   /// Generator for `#omp teams`
   ///
-  /// \param Loc The location where the task construct was encountered.
+  /// \param Loc The location where the teams construct was encountered.
   /// \param BodyGenCB Callback that will generate the region code.
   InsertPointTy createTeams(const LocationDescription &Loc,
                             BodyGenCallbackTy BodyGenCB);
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 0f2203f0c1ac84c..a873d95366bc9a6 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6111,9 +6111,8 @@ void OpenMPIRBuilder::loadOffloadInfoMetadata(Module &M) {
 OpenMPIRBuilder::InsertPointTy
 OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
                              BodyGenCallbackTy BodyGenCB) {
-  if (!updateToLocation(Loc)) {
+  if (!updateToLocation(Loc))
     return Loc.IP;
-  }
 
   uint32_t SrcLocStrSize;
   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
@@ -6198,7 +6197,7 @@ OpenMPIRBuilder::createTeams(const LocationDescription &Loc,
       WrapperArgTys.push_back(Arg.getType());
     }
     FunctionCallee WrapperFuncVal = M.getOrInsertFunction(
-        "outlined_omp_teams",
+        (Twine(OutlinedFn.getName()) + ".teams").str(),
         FunctionType::get(Builder.getVoidTy(), WrapperArgTys, false));
     Function *WrapperFunc = dyn_cast<Function>(WrapperFuncVal.getCallee());
     WrapperFunc->getArg(0)->setName("global_tid");



More information about the llvm-commits mailing list