[llvm] [OpenMPIRBuilder] Add support for target workshare loops (PR #73360)

Dominik Adamski via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 30 02:42:09 PST 2023


https://github.com/DominikAdamski updated https://github.com/llvm/llvm-project/pull/73360

>From 37015706a5e33bcb421926545a0e7f0e4a61ca28 Mon Sep 17 00:00:00 2001
From: Dominik Adamski <dominik.adamski at amd.com>
Date: Thu, 28 Sep 2023 14:18:36 -0400
Subject: [PATCH 1/3] [OpenMPIRBuilder] Add support for target workshare loops

The workshare loop for a target region uses the new OpenMP device
runtime. The code generation scheme for the new device runtime
is presented below:

Input code:
workshare-loop {
  loop-body
}

Output code:
helper function:
function-loop-body(counter, loop-body-args) {
  loop-body
}
workshare-loop is replaced by the proper device runtime call:
call __kmpc_new_worksharing_rtl(function-loop-body, loop-body-args,
                                loop-tripcount, ...)
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  37 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 246 +++++++++++++++++-
 .../Frontend/OpenMPIRBuilderTest.cpp          |  67 +++++
 3 files changed, 348 insertions(+), 2 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 334eaf01a59c9ce..e1a9214c5f26598 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -439,6 +439,16 @@ class OffloadEntriesInfoManager {
 /// Each OpenMP directive has a corresponding public generator method.
 class OpenMPIRBuilder {
 public:
+  /// A type of worksharing loop construct
+  enum class WorksharingLoopType {
+    // Worksharing `for`-loop
+    ForStaticLoop,
+    // Worksharing `distrbute`-loop
+    DistributeStaticLoop,
+    // Worksharing `distrbute parallel for`-loop
+    DistributeForStaticLoop
+  };
+
   /// Create a new OpenMPIRBuilder operating on the given module \p M. This will
   /// not have an effect on \p M (see initialize)
   OpenMPIRBuilder(Module &M)
@@ -900,6 +910,28 @@ class OpenMPIRBuilder {
                               omp::OpenMPOffloadMappingFlags MemberOfFlag);
 
 private:
+  /// Modifies the canonical loop to be a statically-scheduled workshare loop
+  /// which is executed on the device
+  ///
+  /// This takes a \p LoopInfo representing a canonical loop, such as the one
+  /// created by \p createCanonicalLoop and emits additional instructions to
+  /// turn it into a workshare loop. In particular, it calls to an OpenMP
+  /// runtime function in the preheader to call OpenMP device rtl function
+  /// which handles worksharing of loop body interations.
+  ///
+  /// \param DL       Debug location for instructions added for the
+  ///                 workshare-loop construct itself.
+  /// \param CLI      A descriptor of the canonical loop to workshare.
+  /// \param AllocaIP An insertion point for Alloca instructions usable in the
+  ///                 preheader of the loop.
+  /// \param LoopType Information about type of loop worksharing.
+  ///                 It corresponds to type of loop workshare OpenMP pragma.
+  ///
+  /// \returns Point where to insert code after the workshare construct.
+  InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
+                                         InsertPointTy AllocaIP,
+                                         WorksharingLoopType LoopType);
+
   /// Modifies the canonical loop to be a statically-scheduled workshare loop.
   ///
   /// This takes a \p LoopInfo representing a canonical loop, such as the one
@@ -1012,6 +1044,8 @@ class OpenMPIRBuilder {
   ///                                present in the schedule clause.
   /// \param HasOrderedClause Whether the (parameterless) ordered clause is
   ///                         present.
+  /// \param LoopType Information about type of loop worksharing.
+  ///                 It corresponds to type of loop workshare OpenMP pragma.
   ///
   /// \returns Point where to insert code after the workshare construct.
   InsertPointTy applyWorkshareLoop(
@@ -1020,7 +1054,8 @@ class OpenMPIRBuilder {
       llvm::omp::ScheduleKind SchedKind = llvm::omp::OMP_SCHEDULE_Default,
       Value *ChunkSize = nullptr, bool HasSimdModifier = false,
       bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
-      bool HasOrderedClause = false);
+      bool HasOrderedClause = false,
+      WorksharingLoopType LoopType = WorksharingLoopType::ForStaticLoop);
 
   /// Tile a loop nest.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 690d6cbaa67b38d..ebebd5304694724 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2674,11 +2674,255 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
   return {DispatchAfter, DispatchAfter->getFirstInsertionPt()};
 }
 
+// Returns an LLVM function to call for executing an OpenMP static worksharing
+// for loop depending on `type`. Only i32 and i64 are supported by the runtime.
+// Always interpret integers as unsigned similarly to CanonicalLoopInfo.
+static FunctionCallee
+getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  unsigned Bitwidth = Ty->getIntegerBitWidth();
+  Module &M = OMPBuilder->M;
+  switch (LoopType) {
+  case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
+    break;
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
+    break;
+  case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
+    if (Bitwidth == 32)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
+    if (Bitwidth == 64)
+      return OMPBuilder->getOrCreateRuntimeFunction(
+          M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
+    break;
+  }
+  if (Bitwidth != 32 && Bitwidth != 64)
+    llvm_unreachable("unknown OpenMP loop iterator bitwidth");
+  return FunctionCallee();
+}
+
+// Inserts a call to proper OpenMP Device RTL function which handles
+// loop worksharing.
+static void createTargetLoopWorkshareCall(
+    OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
+    BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
+    Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
+  Type *TripCountTy = TripCount->getType();
+  Module &M = OMPBuilder->M;
+  IRBuilder<> &Builder = OMPBuilder->Builder;
+  FunctionCallee RTLFn =
+      getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
+  SmallVector<Value *, 8> RealArgs;
+  RealArgs.push_back(Ident);
+  /*loop body func*/
+  RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
+  /*loop body args*/
+  RealArgs.push_back(LoopBodyArg);
+  /*num of iters*/
+  RealArgs.push_back(TripCount);
+  if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
+    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+                                           ? Builder.getInt32(0)
+                                           : Builder.getInt64(0));
+    Builder.CreateCall(RTLFn, RealArgs);
+    return;
+  }
+  FunctionCallee RTLNumThreads = OMPBuilder->getOrCreateRuntimeFunction(
+      M, omp::RuntimeFunction::OMPRTL_omp_get_num_threads);
+  Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
+  Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
+
+  /*num of threads*/ RealArgs.push_back(
+      Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
+  if (LoopType ==
+      OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
+    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+                                           ? Builder.getInt32(0)
+                                           : Builder.getInt64(0));
+  }
+  /*thread chunk */ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
+                                           ? Builder.getInt32(1)
+                                           : Builder.getInt64(1));
+
+  Builder.CreateCall(RTLFn, RealArgs);
+}
+
+static void
+workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
+                            CanonicalLoopInfo *CLI, Value *Ident,
+                            Function &OutlinedFn, Type *ParallelTaskPtr,
+                            const SmallVector<Instruction *, 4> &ToBeDeleted,
+                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  IRBuilder<> &Builder = OMPIRBuilder->Builder;
+  BasicBlock *Preheader = CLI->getPreheader();
+  Value *TripCount = CLI->getTripCount();
+
+  // After loop body outlining, the loop body contains only set up
+  // of loop body argument structure and the call to the outlined
+  // loop body function. Firstly, we need to move setup of loop body args
+  // into loop preheader.
+  Preheader->splice(std::prev(Preheader->end()), CLI->getBody(),
+                    CLI->getBody()->begin(), std::prev(CLI->getBody()->end()));
+
+  // The next step is to remove the whole loop. We do not need it anymore.
+  // That's why make an unconditional branch from loop preheader to loop
+  // exit block
+  Builder.restoreIP({Preheader, Preheader->end()});
+  Preheader->getTerminator()->eraseFromParent();
+  Builder.CreateBr(CLI->getExit());
+
+  // Delete dead loop blocks
+  OpenMPIRBuilder::OutlineInfo CleanUpInfo;
+  SmallPtrSet<BasicBlock *, 32> RegionBlockSet;
+  SmallVector<BasicBlock *, 32> BlocksToBeRemoved;
+  CleanUpInfo.EntryBB = CLI->getHeader();
+  CleanUpInfo.ExitBB = CLI->getExit();
+  CleanUpInfo.collectBlocks(RegionBlockSet, BlocksToBeRemoved);
+  DeleteDeadBlocks(BlocksToBeRemoved);
+
+  // Find the instruction which corresponds to loop body argument structure
+  // and remove the call to loop body function instruction.
+  Value *LoopBodyArg;
+  for (auto instIt = Preheader->begin(); instIt != Preheader->end(); ++instIt) {
+    if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
+      if (CallInstruction->getCalledFunction() == &OutlinedFn) {
+        // Check in case no argument structure has been passed.
+        if (CallInstruction->arg_size() > 1)
+          LoopBodyArg = CallInstruction->getArgOperand(1);
+        else
+          LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
+        CallInstruction->eraseFromParent();
+        break;
+      }
+    }
+  }
+
+  createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
+                                LoopBodyArg, ParallelTaskPtr, TripCount,
+                                OutlinedFn);
+
+  for (auto &ToBeDeletedItem : ToBeDeleted)
+    ToBeDeletedItem->eraseFromParent();
+  CLI->invalidate();
+}
+
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
+    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
+    OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  uint32_t SrcLocStrSize;
+  Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
+  Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
+
+  OutlineInfo OI;
+  OI.OuterAllocaBB = CLI->getPreheader();
+  Function *OuterFn = CLI->getPreheader()->getParent();
+
+  // Instructions which need to be deleted at the end of code generation
+  SmallVector<Instruction *, 4> ToBeDeleted;
+
+  OI.OuterAllocaBB = AllocaIP.getBlock();
+
+  // Mark the body loop as region which needs to be extracted
+  OI.EntryBB = CLI->getBody();
+  OI.ExitBB = CLI->getLatch()->splitBasicBlock(CLI->getLatch()->begin(),
+                                               "omp.prelatch", true);
+
+  // Prepare loop body for extraction
+  Builder.restoreIP({CLI->getPreheader(), CLI->getPreheader()->begin()});
+
+  // Insert new loop counter variable which will be used only in loop
+  // body.
+  AllocaInst *newLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
+  Instruction *newLoopCntLoad =
+      Builder.CreateLoad(CLI->getIndVarType(), newLoopCnt);
+  // New loop counter instructions are redundant in the loop preheader when
+  // code generation for workshare loop is finshed. That's why mark them as
+  // ready for deletion.
+  ToBeDeleted.push_back(newLoopCntLoad);
+  ToBeDeleted.push_back(newLoopCnt);
+
+  // Analyse loop body region. Find all input variables which are used inside
+  // loop body region.
+  SmallPtrSet<BasicBlock *, 32> ParallelRegionBlockSet;
+  SmallVector<BasicBlock *, 32> Blocks;
+  OI.collectBlocks(ParallelRegionBlockSet, Blocks);
+  SmallVector<BasicBlock *, 32> BlocksT(ParallelRegionBlockSet.begin(),
+                                        ParallelRegionBlockSet.end());
+
+  CodeExtractorAnalysisCache CEAC(*OuterFn);
+  CodeExtractor Extractor(Blocks,
+                          /* DominatorTree */ nullptr,
+                          /* AggregateArgs */ true,
+                          /* BlockFrequencyInfo */ nullptr,
+                          /* BranchProbabilityInfo */ nullptr,
+                          /* AssumptionCache */ nullptr,
+                          /* AllowVarArgs */ true,
+                          /* AllowAlloca */ true,
+                          /* AllocationBlock */ CLI->getPreheader(),
+                          /* Suffix */ ".omp_wsloop",
+                          /* AggrArgsIn0AddrSpace */ true);
+
+  BasicBlock *CommonExit = nullptr;
+  SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
+
+  // Find allocas outside the loop body region which are used inside loop
+  // body
+  Extractor.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
+
+  // We need to model loop body region as the function f(cnt, loop_arg).
+  // That's why we replace loop induction variable by the new counter
+  // which will be one of loop body function argument
+  std::vector<User *> Users(CLI->getIndVar()->user_begin(),
+                            CLI->getIndVar()->user_end());
+  for (User *use : Users) {
+    if (Instruction *inst = dyn_cast<Instruction>(use)) {
+      if (ParallelRegionBlockSet.count(inst->getParent())) {
+        inst->replaceUsesOfWith(CLI->getIndVar(), newLoopCntLoad);
+      }
+    }
+  }
+  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
+  for (Value *Input : Inputs) {
+    // Make sure that loop counter variable is not merged into loop body
+    // function argument structure and it is passed as separate variable
+    if (Input == newLoopCntLoad)
+      OI.ExcludeArgsFromAggregate.push_back(Input);
+  }
+
+  // PostOutline CB is invoked when loop body function is outlined and
+  // loop body is replaced by call to outlined function. We need to add
+  // call to OpenMP device rtl inside loop preheader. OpenMP device rtl
+  // function will handle loop control logic.
+  //
+  OI.PostOutlineCB = [=, ToBeDeletedVec =
+                             std::move(ToBeDeleted)](Function &OutlinedFn) {
+    workshareLoopTargetCallback(this, CLI, Ident, OutlinedFn, ParallelTaskPtr,
+                                ToBeDeletedVec, LoopType);
+  };
+  addOutlineInfo(std::move(OI));
+  return CLI->getAfterIP();
+}
+
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
     bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
     llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
-    bool HasNonmonotonicModifier, bool HasOrderedClause) {
+    bool HasNonmonotonicModifier, bool HasOrderedClause,
+    OpenMPIRBuilder::WorksharingLoopType LoopType) {
+  if (Config.isTargetDevice())
+    return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
       SchedKind, ChunkSize, HasSimdModifier, HasMonotonicModifier,
       HasNonmonotonicModifier, HasOrderedClause);
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 2876aa49da402e0..823a70368dff251 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -2228,9 +2228,73 @@ TEST_F(OpenMPIRBuilderTest, UnrollLoopHeuristic) {
   EXPECT_TRUE(getBooleanLoopAttribute(L, "llvm.loop.unroll.enable"));
 }
 
+TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  std::string oldDLStr = M->getDataLayoutStr();
+  M->setDataLayout(
+      "e-p:64:64-p1:64:64-p2:32:32-p3:32:32-p4:64:64-p5:32:32-p6:32:32-p7:160:"
+      "256:256:32-p8:128:128-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:"
+      "256-v256:256-v512:512-v1024:1024-v2048:2048-n32:64-S32-A5-G1-ni:7:8");
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = true;
+  OMPBuilder.initialize();
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+  InsertPointTy AllocaIP = Builder.saveIP();
+
+  Type *LCTy = Type::getInt32Ty(Ctx);
+  Value *StartVal = ConstantInt::get(LCTy, 10);
+  Value *StopVal = ConstantInt::get(LCTy, 52);
+  Value *StepVal = ConstantInt::get(LCTy, 2);
+  auto LoopBodyGen = [&](InsertPointTy, llvm::Value *) {};
+
+  CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop(
+      Loc, LoopBodyGen, StartVal, StopVal, StepVal, false, false);
+  BasicBlock *Preheader = CLI->getPreheader();
+  Value *TripCount = CLI->getTripCount();
+
+  Builder.SetInsertPoint(BB, BB->getFirstInsertionPt());
+
+  IRBuilder<>::InsertPoint AfterIP = OMPBuilder.applyWorkshareLoop(
+      DL, CLI, AllocaIP, true, OMP_SCHEDULE_Static, nullptr, false, false,
+      false, false, OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop);
+  Builder.restoreIP(AfterIP);
+  Builder.CreateRetVoid();
+
+  OMPBuilder.finalize();
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+
+  CallInst *WorkshareLoopRuntimeCall = nullptr;
+  for (auto Inst = Preheader->begin(); Inst != Preheader->end(); ++Inst) {
+    CallInst *Call = dyn_cast<CallInst>(Inst);
+    if (Call) {
+      if (Call->getCalledFunction()) {
+        if (Call->getCalledFunction()->getName() ==
+            "__kmpc_for_static_loop_4u") {
+          WorkshareLoopRuntimeCall = Call;
+        }
+      }
+    }
+  }
+  EXPECT_NE(WorkshareLoopRuntimeCall, nullptr);
+  // Check that pointer to loop body function is passed as second argument
+  Value *LoopBodyFuncArg = WorkshareLoopRuntimeCall->getArgOperand(1);
+  EXPECT_EQ(Builder.getPtrTy(), LoopBodyFuncArg->getType());
+  Function *ArgFunction = dyn_cast<Function>(LoopBodyFuncArg);
+  EXPECT_NE(ArgFunction, nullptr);
+  EXPECT_EQ(ArgFunction->arg_size(), 1);
+  EXPECT_EQ(ArgFunction->getArg(0)->getType(), TripCount->getType());
+  // Check that no variables except for loop counter are used in loop body
+  EXPECT_EQ(Constant::getNullValue(Builder.getPtrTy()),
+            WorkshareLoopRuntimeCall->getArgOperand(2));
+  // Check loop trip count argument
+  EXPECT_EQ(TripCount, WorkshareLoopRuntimeCall->getArgOperand(3));
+}
+
 TEST_F(OpenMPIRBuilderTest, StaticWorkShareLoop) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = false;
   OMPBuilder.initialize();
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
@@ -2331,6 +2395,7 @@ TEST_P(OpenMPIRBuilderTestWithIVBits, StaticChunkedWorkshareLoop) {
 
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = false;
 
   BasicBlock *Body;
   CallInst *Call;
@@ -2405,6 +2470,7 @@ INSTANTIATE_TEST_SUITE_P(IVBits, OpenMPIRBuilderTestWithIVBits,
 TEST_P(OpenMPIRBuilderTestWithParams, DynamicWorkShareLoop) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = false;
   OMPBuilder.initialize();
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
@@ -2562,6 +2628,7 @@ INSTANTIATE_TEST_SUITE_P(
 TEST_F(OpenMPIRBuilderTest, DynamicWorkShareLoopOrdered) {
   using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
   OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.Config.IsTargetDevice = false;
   OMPBuilder.initialize();
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});

>From f6a550d0202c6fc2e2aea8cd59e4d77fa6efa835 Mon Sep 17 00:00:00 2001
From: Dominik Adamski <dominik.adamski at amd.com>
Date: Fri, 24 Nov 2023 13:40:38 -0600
Subject: [PATCH 2/3] Remove redundant llvm::

---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp       | 4 ++--
 llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index ebebd5304694724..4e8419586494b13 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2917,8 +2917,8 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
 
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
     DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    bool NeedsBarrier, llvm::omp::ScheduleKind SchedKind,
-    llvm::Value *ChunkSize, bool HasSimdModifier, bool HasMonotonicModifier,
+    bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
+    bool HasSimdModifier, bool HasMonotonicModifier,
     bool HasNonmonotonicModifier, bool HasOrderedClause,
     OpenMPIRBuilder::WorksharingLoopType LoopType) {
   if (Config.isTargetDevice())
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 823a70368dff251..d3a5b6296d8657c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -2246,7 +2246,7 @@ TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
   Value *StartVal = ConstantInt::get(LCTy, 10);
   Value *StopVal = ConstantInt::get(LCTy, 52);
   Value *StepVal = ConstantInt::get(LCTy, 2);
-  auto LoopBodyGen = [&](InsertPointTy, llvm::Value *) {};
+  auto LoopBodyGen = [&](InsertPointTy, Value *) {};
 
   CanonicalLoopInfo *CLI = OMPBuilder.createCanonicalLoop(
       Loc, LoopBodyGen, StartVal, StopVal, StepVal, false, false);

>From b8c876d537b3388ac0d53a992eac9f019843dde2 Mon Sep 17 00:00:00 2001
From: Dominik Adamski <dominik.adamski at amd.com>
Date: Wed, 29 Nov 2023 05:22:02 -0600
Subject: [PATCH 3/3] Applied remarks

---
 .../llvm/Frontend/OpenMP/OMPConstants.h       |  10 ++
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  19 +---
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 107 ++++++++----------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  20 ++--
 4 files changed, 74 insertions(+), 82 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
index 32dcdd587f3b31a..f8812e7955b82d0 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPConstants.h
@@ -277,6 +277,16 @@ enum class RTLDependenceKindTy {
   DepOmpAllMem = 0x80,
 };
 
+/// A type of worksharing loop construct
+enum class WorksharingLoopType {
+  // Worksharing `for`-loop
+  ForStaticLoop,
+  // Worksharing `distribute`-loop
+  DistributeStaticLoop,
+  // Worksharing `distribute parallel for`-loop
+  DistributeForStaticLoop
+};
+
 } // end namespace omp
 
 } // end namespace llvm
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index e1a9214c5f26598..abbef03d02cb101 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -439,16 +439,6 @@ class OffloadEntriesInfoManager {
 /// Each OpenMP directive has a corresponding public generator method.
 class OpenMPIRBuilder {
 public:
-  /// A type of worksharing loop construct
-  enum class WorksharingLoopType {
-    // Worksharing `for`-loop
-    ForStaticLoop,
-    // Worksharing `distrbute`-loop
-    DistributeStaticLoop,
-    // Worksharing `distrbute parallel for`-loop
-    DistributeForStaticLoop
-  };
-
   /// Create a new OpenMPIRBuilder operating on the given module \p M. This will
   /// not have an effect on \p M (see initialize)
   OpenMPIRBuilder(Module &M)
@@ -913,8 +903,8 @@ class OpenMPIRBuilder {
   /// Modifies the canonical loop to be a statically-scheduled workshare loop
   /// which is executed on the device
   ///
-  /// This takes a \p LoopInfo representing a canonical loop, such as the one
-  /// created by \p createCanonicalLoop and emits additional instructions to
+  /// This takes a \p CLI representing a canonical loop, such as the one
+  /// created by \see createCanonicalLoop and emits additional instructions to
   /// turn it into a workshare loop. In particular, it calls to an OpenMP
   /// runtime function in the preheader to call OpenMP device rtl function
   /// which handles worksharing of loop body interations.
@@ -930,7 +920,7 @@ class OpenMPIRBuilder {
   /// \returns Point where to insert code after the workshare construct.
   InsertPointTy applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
                                          InsertPointTy AllocaIP,
-                                         WorksharingLoopType LoopType);
+                                         omp::WorksharingLoopType LoopType);
 
   /// Modifies the canonical loop to be a statically-scheduled workshare loop.
   ///
@@ -1055,7 +1045,8 @@ class OpenMPIRBuilder {
       Value *ChunkSize = nullptr, bool HasSimdModifier = false,
       bool HasMonotonicModifier = false, bool HasNonmonotonicModifier = false,
       bool HasOrderedClause = false,
-      WorksharingLoopType LoopType = WorksharingLoopType::ForStaticLoop);
+      omp::WorksharingLoopType LoopType =
+          omp::WorksharingLoopType::ForStaticLoop);
 
   /// Tile a loop nest.
   ///
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 4e8419586494b13..38af4e3cf94bad4 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -2679,11 +2679,11 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyStaticChunkedWorkshareLoop(
 // Always interpret integers as unsigned similarly to CanonicalLoopInfo.
 static FunctionCallee
 getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
-                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+                            WorksharingLoopType LoopType) {
   unsigned Bitwidth = Ty->getIntegerBitWidth();
   Module &M = OMPBuilder->M;
   switch (LoopType) {
-  case OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop:
+  case WorksharingLoopType::ForStaticLoop:
     if (Bitwidth == 32)
       return OMPBuilder->getOrCreateRuntimeFunction(
           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_4u);
@@ -2691,7 +2691,7 @@ getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
       return OMPBuilder->getOrCreateRuntimeFunction(
           M, omp::RuntimeFunction::OMPRTL___kmpc_for_static_loop_8u);
     break;
-  case OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop:
+  case WorksharingLoopType::DistributeStaticLoop:
     if (Bitwidth == 32)
       return OMPBuilder->getOrCreateRuntimeFunction(
           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_4u);
@@ -2699,7 +2699,7 @@ getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
       return OMPBuilder->getOrCreateRuntimeFunction(
           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_static_loop_8u);
     break;
-  case OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop:
+  case WorksharingLoopType::DistributeForStaticLoop:
     if (Bitwidth == 32)
       return OMPBuilder->getOrCreateRuntimeFunction(
           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_4u);
@@ -2708,15 +2708,16 @@ getKmpcForStaticLoopForType(Type *Ty, OpenMPIRBuilder *OMPBuilder,
           M, omp::RuntimeFunction::OMPRTL___kmpc_distribute_for_static_loop_8u);
     break;
   }
-  if (Bitwidth != 32 && Bitwidth != 64)
-    llvm_unreachable("unknown OpenMP loop iterator bitwidth");
-  return FunctionCallee();
+  if (Bitwidth != 32 && Bitwidth != 64) {
+    llvm_unreachable("Unknown OpenMP loop iterator bitwidth");
+  }
+  llvm_unreachable("Unknown type of OpenMP worksharing loop");
 }
 
 // Inserts a call to proper OpenMP Device RTL function which handles
 // loop worksharing.
 static void createTargetLoopWorkshareCall(
-    OpenMPIRBuilder *OMPBuilder, OpenMPIRBuilder::WorksharingLoopType LoopType,
+    OpenMPIRBuilder *OMPBuilder, WorksharingLoopType LoopType,
     BasicBlock *InsertBlock, Value *Ident, Value *LoopBodyArg,
     Type *ParallelTaskPtr, Value *TripCount, Function &LoopBodyFn) {
   Type *TripCountTy = TripCount->getType();
@@ -2726,16 +2727,11 @@ static void createTargetLoopWorkshareCall(
       getKmpcForStaticLoopForType(TripCountTy, OMPBuilder, LoopType);
   SmallVector<Value *, 8> RealArgs;
   RealArgs.push_back(Ident);
-  /*loop body func*/
   RealArgs.push_back(Builder.CreateBitCast(&LoopBodyFn, ParallelTaskPtr));
-  /*loop body args*/
   RealArgs.push_back(LoopBodyArg);
-  /*num of iters*/
   RealArgs.push_back(TripCount);
-  if (LoopType == OpenMPIRBuilder::WorksharingLoopType::DistributeStaticLoop) {
-    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
-                                           ? Builder.getInt32(0)
-                                           : Builder.getInt64(0));
+  if (LoopType == WorksharingLoopType::DistributeStaticLoop) {
+    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
     Builder.CreateCall(RTLFn, RealArgs);
     return;
   }
@@ -2744,17 +2740,12 @@ static void createTargetLoopWorkshareCall(
   Builder.restoreIP({InsertBlock, std::prev(InsertBlock->end())});
   Value *NumThreads = Builder.CreateCall(RTLNumThreads, {});
 
-  /*num of threads*/ RealArgs.push_back(
+  RealArgs.push_back(
       Builder.CreateZExtOrTrunc(NumThreads, TripCountTy, "num.threads.cast"));
-  if (LoopType ==
-      OpenMPIRBuilder::WorksharingLoopType::DistributeForStaticLoop) {
-    /*block chunk*/ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
-                                           ? Builder.getInt32(0)
-                                           : Builder.getInt64(0));
+  RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
+  if (LoopType == WorksharingLoopType::DistributeForStaticLoop) {
+    RealArgs.push_back(ConstantInt::get(TripCountTy, 0));
   }
-  /*thread chunk */ RealArgs.push_back(TripCountTy->getIntegerBitWidth() == 32
-                                           ? Builder.getInt32(1)
-                                           : Builder.getInt64(1));
 
   Builder.CreateCall(RTLFn, RealArgs);
 }
@@ -2764,7 +2755,7 @@ workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
                             CanonicalLoopInfo *CLI, Value *Ident,
                             Function &OutlinedFn, Type *ParallelTaskPtr,
                             const SmallVector<Instruction *, 4> &ToBeDeleted,
-                            OpenMPIRBuilder::WorksharingLoopType LoopType) {
+                            WorksharingLoopType LoopType) {
   IRBuilder<> &Builder = OMPIRBuilder->Builder;
   BasicBlock *Preheader = CLI->getPreheader();
   Value *TripCount = CLI->getTripCount();
@@ -2795,19 +2786,19 @@ workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
   // Find the instruction which corresponds to loop body argument structure
   // and remove the call to loop body function instruction.
   Value *LoopBodyArg;
-  for (auto instIt = Preheader->begin(); instIt != Preheader->end(); ++instIt) {
-    if (CallInst *CallInstruction = dyn_cast<CallInst>(instIt)) {
-      if (CallInstruction->getCalledFunction() == &OutlinedFn) {
-        // Check in case no argument structure has been passed.
-        if (CallInstruction->arg_size() > 1)
-          LoopBodyArg = CallInstruction->getArgOperand(1);
-        else
-          LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
-        CallInstruction->eraseFromParent();
-        break;
-      }
-    }
-  }
+  User *OutlinedFnUser = OutlinedFn.getUniqueUndroppableUser();
+  assert(OutlinedFnUser &&
+         "Expected unique undroppable user of outlined function");
+  CallInst *OutlinedFnCallInstruction = dyn_cast<CallInst>(OutlinedFnUser);
+  assert(OutlinedFnCallInstruction && "Expected outlined function call");
+  assert((OutlinedFnCallInstruction->getParent() == Preheader) &&
+         "Expected outlined function call to be located in loop preheader");
+  // Check in case no argument structure has been passed.
+  if (OutlinedFnCallInstruction->arg_size() > 1)
+    LoopBodyArg = OutlinedFnCallInstruction->getArgOperand(1);
+  else
+    LoopBodyArg = Constant::getNullValue(Builder.getPtrTy());
+  OutlinedFnCallInstruction->eraseFromParent();
 
   createTargetLoopWorkshareCall(OMPIRBuilder, LoopType, Preheader, Ident,
                                 LoopBodyArg, ParallelTaskPtr, TripCount,
@@ -2818,9 +2809,10 @@ workshareLoopTargetCallback(OpenMPIRBuilder *OMPIRBuilder,
   CLI->invalidate();
 }
 
-OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
-    DebugLoc DL, CanonicalLoopInfo *CLI, InsertPointTy AllocaIP,
-    OpenMPIRBuilder::WorksharingLoopType LoopType) {
+OpenMPIRBuilder::InsertPointTy
+OpenMPIRBuilder::applyWorkshareLoopTarget(DebugLoc DL, CanonicalLoopInfo *CLI,
+                                          InsertPointTy AllocaIP,
+                                          WorksharingLoopType LoopType) {
   uint32_t SrcLocStrSize;
   Constant *SrcLocStr = getOrCreateSrcLocStr(DL, SrcLocStrSize);
   Value *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
@@ -2844,14 +2836,14 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
 
   // Insert new loop counter variable which will be used only in loop
   // body.
-  AllocaInst *newLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
-  Instruction *newLoopCntLoad =
-      Builder.CreateLoad(CLI->getIndVarType(), newLoopCnt);
+  AllocaInst *NewLoopCnt = Builder.CreateAlloca(CLI->getIndVarType(), 0, "");
+  Instruction *NewLoopCntLoad =
+      Builder.CreateLoad(CLI->getIndVarType(), NewLoopCnt);
   // New loop counter instructions are redundant in the loop preheader when
   // code generation for workshare loop is finished. That's why we mark them
   // as ready for deletion.
-  ToBeDeleted.push_back(newLoopCntLoad);
-  ToBeDeleted.push_back(newLoopCnt);
+  ToBeDeleted.push_back(NewLoopCntLoad);
+  ToBeDeleted.push_back(NewLoopCnt);
 
   // Analyse loop body region. Find all input variables which are used inside
   // loop body region.
@@ -2884,22 +2876,17 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoopTarget(
   // We need to model the loop body region as the function f(cnt, loop_arg).
   // That's why we replace the loop induction variable by the new counter,
   // which will be one of the loop body function's arguments.
-  std::vector<User *> Users(CLI->getIndVar()->user_begin(),
-                            CLI->getIndVar()->user_end());
-  for (User *use : Users) {
-    if (Instruction *inst = dyn_cast<Instruction>(use)) {
-      if (ParallelRegionBlockSet.count(inst->getParent())) {
-        inst->replaceUsesOfWith(CLI->getIndVar(), newLoopCntLoad);
+  for (auto Use = CLI->getIndVar()->user_begin();
+       Use != CLI->getIndVar()->user_end(); ++Use) {
+    if (Instruction *Inst = dyn_cast<Instruction>(*Use)) {
+      if (ParallelRegionBlockSet.count(Inst->getParent())) {
+        Inst->replaceUsesOfWith(CLI->getIndVar(), NewLoopCntLoad);
       }
     }
   }
-  Extractor.findInputsOutputs(Inputs, Outputs, SinkingCands);
-  for (Value *Input : Inputs) {
-    // Make sure that loop counter variable is not merged into loop body
-    // function argument structure and it is passed as separate variable
-    if (Input == newLoopCntLoad)
-      OI.ExcludeArgsFromAggregate.push_back(Input);
-  }
+  // Make sure that loop counter variable is not merged into loop body
+  // function argument structure and it is passed as separate variable
+  OI.ExcludeArgsFromAggregate.push_back(NewLoopCntLoad);
 
   // PostOutline CB is invoked when loop body function is outlined and
   // loop body is replaced by call to outlined function. We need to add
@@ -2920,7 +2907,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::applyWorkshareLoop(
     bool NeedsBarrier, omp::ScheduleKind SchedKind, Value *ChunkSize,
     bool HasSimdModifier, bool HasMonotonicModifier,
     bool HasNonmonotonicModifier, bool HasOrderedClause,
-    OpenMPIRBuilder::WorksharingLoopType LoopType) {
+    WorksharingLoopType LoopType) {
   if (Config.isTargetDevice())
     return applyWorkshareLoopTarget(DL, CLI, AllocaIP, LoopType);
   OMPScheduleType EffectiveScheduleType = computeOpenMPScheduleType(
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index d3a5b6296d8657c..3dfbafbe9067842 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -2257,7 +2257,7 @@ TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
 
   IRBuilder<>::InsertPoint AfterIP = OMPBuilder.applyWorkshareLoop(
       DL, CLI, AllocaIP, true, OMP_SCHEDULE_Static, nullptr, false, false,
-      false, false, OpenMPIRBuilder::WorksharingLoopType::ForStaticLoop);
+      false, false, WorksharingLoopType::ForStaticLoop);
   Builder.restoreIP(AfterIP);
   Builder.CreateRetVoid();
 
@@ -2265,18 +2265,22 @@ TEST_F(OpenMPIRBuilderTest, StaticWorkshareLoopTarget) {
   EXPECT_FALSE(verifyModule(*M, &errs()));
 
   CallInst *WorkshareLoopRuntimeCall = nullptr;
+  int WorkshareLoopRuntimeCallCnt = 0;
   for (auto Inst = Preheader->begin(); Inst != Preheader->end(); ++Inst) {
     CallInst *Call = dyn_cast<CallInst>(Inst);
-    if (Call) {
-      if (Call->getCalledFunction()) {
-        if (Call->getCalledFunction()->getName() ==
-            "__kmpc_for_static_loop_4u") {
-          WorkshareLoopRuntimeCall = Call;
-        }
-      }
+    if (!Call)
+      continue;
+    if (!Call->getCalledFunction())
+      continue;
+
+    if (Call->getCalledFunction()->getName() == "__kmpc_for_static_loop_4u") {
+      WorkshareLoopRuntimeCall = Call;
+      WorkshareLoopRuntimeCallCnt++;
     }
   }
   EXPECT_NE(WorkshareLoopRuntimeCall, nullptr);
+  // Verify that there is only one call to workshare loop function
+  EXPECT_EQ(WorkshareLoopRuntimeCallCnt, 1);
   // Check that pointer to loop body function is passed as second argument
   Value *LoopBodyFuncArg = WorkshareLoopRuntimeCall->getArgOperand(1);
   EXPECT_EQ(Builder.getPtrTy(), LoopBodyFuncArg->getType());



More information about the llvm-commits mailing list