[clang] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)

Sergio Afonso via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 03:25:44 PST 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051

>From 0c19f7119c1da0646466b0eb1c3c77faedabaebf Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 11:33:01 +0000
Subject: [PATCH] [OMPIRBuilder] Support runtime number of teams and threads,
 and SPMD mode

This patch introduces a `TargetKernelRuntimeAttrs` structure to hold
host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.

Additionally, kernel type information is used to influence target device code
generation, and the `IsSPMD` flag is replaced by `ExecFlags`, which provides
more granularity.
---
 clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp      |   5 +-
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  38 ++-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 129 +++++---
 .../Frontend/OpenMPIRBuilderTest.cpp          | 281 ++++++++++++++++--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  12 +-
 5 files changed, 398 insertions(+), 67 deletions(-)

diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index 81993dafae2b03..87c3635ed3f70e 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -20,6 +20,7 @@
 #include "clang/AST/StmtVisitor.h"
 #include "clang/Basic/Cuda.h"
 #include "llvm/ADT/SmallPtrSet.h"
+#include "llvm/Frontend/OpenMP/OMPDeviceConstants.h"
 #include "llvm/Frontend/OpenMP/OMPGridValues.h"
 
 using namespace clang;
@@ -745,7 +746,9 @@ void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
                                         CodeGenFunction &CGF,
                                         EntryFunctionState &EST, bool IsSPMD) {
   llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
-  Attrs.IsSPMD = IsSPMD;
+  Attrs.ExecFlags =
+      IsSPMD ? llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD
+             : llvm::omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
   computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
 
   CGBuilderTy &Bld = CGF.Builder;
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 8ca3bc08b5ad49..7eceec3d8cf8f5 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1389,9 +1389,6 @@ class OpenMPIRBuilder {
 
   /// Supporting functions for Reductions CodeGen.
 private:
-  /// Emit the llvm.used metadata.
-  void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
-
   /// Get the id of the current thread on the GPU.
   Value *getGPUThreadID();
 
@@ -2013,6 +2010,13 @@ class OpenMPIRBuilder {
   /// Value.
   GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
 
+  /// Emit the llvm.used metadata.
+  void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
+
+  /// Emit the kernel execution mode.
+  GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
+                                          omp::OMPTgtExecModeFlags Mode);
+
   /// Generate control flow and cleanup for cancellation.
   ///
   /// \param CancelFlag Flag indicating if the cancellation is performed.
@@ -2233,13 +2237,34 @@ class OpenMPIRBuilder {
   /// time. The number of max values will be 1 except for the case where
   /// ompx_bare is set.
   struct TargetKernelDefaultAttrs {
-    bool IsSPMD = false;
+    omp::OMPTgtExecModeFlags ExecFlags =
+        omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC;
     SmallVector<int32_t, 3> MaxTeams = {-1};
     int32_t MinTeams = 1;
     SmallVector<int32_t, 3> MaxThreads = {-1};
     int32_t MinThreads = 1;
   };
 
+  /// Container to pass LLVM IR runtime values or constants related to the
+  /// number of teams and threads with which the kernel must be launched, as
+  /// well as the trip count of the loop, if it is an SPMD or Generic-SPMD
+  /// kernel. These must be defined in the host prior to the call to the kernel
+  /// launch OpenMP RTL function.
+  struct TargetKernelRuntimeAttrs {
+    SmallVector<Value *, 3> MaxTeams = {nullptr};
+    Value *MinTeams = nullptr;
+    SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+    SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+    /// 'parallel' construct 'num_threads' clause value, if present and it is an
+    /// SPMD kernel.
+    Value *MaxThreads = nullptr;
+
+    /// Total number of iterations of the SPMD or Generic-SPMD kernel or null if
+    /// it is a generic kernel.
+    Value *LoopTripCount = nullptr;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2971,7 +2996,9 @@ class OpenMPIRBuilder {
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
-  /// \param DefaultAttrs Structure containing the default numbers of threads
+  /// \param DefaultAttrs Structure containing the default attributes, including
+  ///        numbers of threads and teams to launch the kernel with.
+  /// \param RuntimeAttrs Structure containing the runtime numbers of threads
   ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
@@ -2987,6 +3014,7 @@ class OpenMPIRBuilder {
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index df9b35ddd80ca4..3242b38502300d 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -830,6 +830,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
   return GV;
 }
 
+void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
+  if (List.empty())
+    return;
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.resize(List.size());
+  for (unsigned I = 0, E = List.size(); I != E; ++I)
+    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*List[I]), Builder.getPtrTy());
+
+  if (UsedArray.empty())
+    return;
+  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+                                ConstantArray::get(ATy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+GlobalVariable *
+OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
+                                         OMPTgtExecModeFlags Mode) {
+  auto *Int8Ty = Builder.getInt8Ty();
+  auto *GVMode = new GlobalVariable(
+      M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
+      ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
+  GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
+  return GVMode;
+}
+
 Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                             uint32_t SrcLocStrSize,
                                             IdentFlag LocFlags,
@@ -2260,28 +2292,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
   return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
 }
 
-void OpenMPIRBuilder::emitUsed(StringRef Name,
-                               std::vector<WeakTrackingVH> &List) {
-  if (List.empty())
-    return;
-
-  // Convert List to what ConstantArray needs.
-  SmallVector<Constant *, 8> UsedArray;
-  UsedArray.resize(List.size());
-  for (unsigned I = 0, E = List.size(); I != E; ++I)
-    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
-        cast<Constant>(&*List[I]), Builder.getPtrTy());
-
-  if (UsedArray.empty())
-    return;
-  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
-
-  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
-                                ConstantArray::get(ATy, UsedArray), Name);
-
-  GV->setSection("llvm.metadata");
-}
-
 Value *OpenMPIRBuilder::getGPUThreadID() {
   return Builder.CreateCall(
       getOrCreateRuntimeFunction(M,
@@ -6131,10 +6141,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
   uint32_t SrcLocStrSize;
   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
-  Constant *IsSPMDVal = ConstantInt::getSigned(
-      Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
-  Constant *UseGenericStateMachineVal =
-      ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
+  Constant *IsSPMDVal = ConstantInt::getSigned(Int8, Attrs.ExecFlags);
+  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(
+      Int8, Attrs.ExecFlags != omp::OMP_TGT_EXEC_MODE_SPMD);
   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
 
@@ -6765,6 +6774,12 @@ static Expected<Function *> createOutlinedFunction(
   auto Func =
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
+  if (OMPBuilder.Config.isTargetDevice()) {
+    Value *ExecMode =
+        OMPBuilder.emitKernelExecutionMode(FuncName, DefaultAttrs.ExecFlags);
+    OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
+  }
+
   // Save insert point.
   IRBuilder<>::InsertPointGuard IPG(Builder);
   // If there's a DISubprogram associated with current function, then
@@ -7312,6 +7327,7 @@ static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
                Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7393,11 +7409,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                          /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  for (auto [DefaultVal, RuntimeVal] :
+       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+    if (Clause)
+      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                     /*isSigned=*/false);
+    return Clause;
+  };
+  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+    if (Clause)
+      Result = Result
+                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+                                          Result, Clause)
+                   : Clause;
+  };
+
+  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+  // the NUM_THREADS clause is overridden by THREAD_LIMIT.
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : DefaultAttrs.MaxTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : DefaultAttrs.MaxThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                : nullptr;
+
+  for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
+                                              RuntimeAttrs.TargetThreadLimit)) {
+    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+  }
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
@@ -7406,14 +7454,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                              llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
+
+  Value *TripCount = RuntimeAttrs.LoopTripCount
+                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                 Builder.getInt64Ty(),
+                                                 /*isSigned=*/false)
+                         : Builder.getInt64(0);
+
   // TODO: Use correct DynCGGroupMem
   Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
+  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                            NumTeamsC, NumThreadsC,
+                                            DynCGGroupMem, HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7438,6 +7491,7 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7462,8 +7516,9 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
-                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+                   OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index 04ecd7ef327d5a..11f13beb9865c0 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6170,7 +6170,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   OMPBuilder.setConfig(Config);
   F->setName("func");
   IRBuilder<> Builder(BB);
-  auto Int32Ty = Builder.getInt32Ty();
+  auto *Int32Ty = Builder.getInt32Ty();
 
   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6229,16 +6229,22 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
-      /*MinThreads=*/0};
+      /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+      /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+  RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+  RuntimeAttrs.MaxThreads = Builder.getInt32(40);
 
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+                              Builder.saveIP(), EntryInfo, DefaultAttrs,
+                              RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
+                              SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
+
   OMPBuilder.finalize();
   Builder.CreateRetVoid();
 
@@ -6256,6 +6262,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName = KernelLaunchFunc->getName();
   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
 
+  // Check num_teams and num_threads in call arguments
+  EXPECT_TRUE(Call->arg_size() >= 4);
+  Value *NumTeamsArg = Call->getArgOperand(2);
+  EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+  EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+  Value *NumThreadsArg = Call->getArgOperand(3);
+  EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+  EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+  // Check num_teams and num_threads kernel arguments (use number 5 starting
+  // from the end and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+  Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+  Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+  Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+  auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+  EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+  EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+  Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+  Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+  Value *NumThreadsStoreArg =
+      cast<StoreInst>(NumThreadsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+  auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+  EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+  EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
   // Check the fallback call
   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
   Iter = FallbackBlock->rbegin();
@@ -6343,15 +6386,16 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
                                    F->getEntryBlock().getFirstInsertionPt());
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
-      /*MinThreads=*/0};
+      /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(
-      OpenMPIRBuilder::InsertPointTy, AfterIP,
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, DefaultAttrs, CapturedArgs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
+                       OMPBuilder.createTarget(
+                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
+                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
@@ -6435,6 +6479,204 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
   EXPECT_EQ(ExitBlock->getName(), "worker.exit");
   EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+
+  // Check global exec_mode.
+  GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+  EXPECT_NE(Used, nullptr);
+  Constant *UsedInit = Used->getInitializer();
+  EXPECT_NE(UsedInit, nullptr);
+  EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+  auto *UsedInitData = cast<ConstantArray>(UsedInit);
+  EXPECT_EQ(1U, UsedInitData->getNumOperands());
+  Constant *ExecMode = UsedInitData->getOperand(0);
+  EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+  Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+  EXPECT_NE(ExecModeValue, nullptr);
+  EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+  EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC,
+            cast<ConstantInt>(ExecModeValue)->getZExtValue());
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false,
+                               /*OpenMPOffloadMandatory=*/false,
+                               /*HasRequiresReverseOffload=*/false,
+                               /*HasRequiresUnifiedAddress=*/false,
+                               /*HasRequiresUnifiedSharedMemory=*/false,
+                               /*HasRequiresDynamicAllocators=*/false);
+  OMPBuilder.setConfig(Config);
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto BodyGenCB = [&](InsertPointTy,
+                       InsertPointTy CodeGenIP) -> InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
+
+  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+                                 OpenMPIRBuilder::InsertPointTy,
+                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
+
+  SmallVector<Value *> Inputs;
+  OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB =
+      [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+    return CombinedInfos;
+  };
+
+  TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
+  OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+
+  ASSERT_EXPECTED_INIT(
+      OpenMPIRBuilder::InsertPointTy, AfterIP,
+      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
+                              Builder.saveIP(), EntryInfo, DefaultAttrs,
+                              RuntimeAttrs, Inputs, GenMapInfoCB, BodyGenCB,
+                              SimpleArgAccessorCB));
+  Builder.restoreIP(AfterIP);
+
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check the kernel launch sequence
+  auto Iter = F->getEntryBlock().rbegin();
+  EXPECT_TRUE(isa<BranchInst>(&*(Iter)));
+  BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
+  EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+  EXPECT_TRUE(isa<CallInst>(&*(++Iter)));
+  CallInst *Call = dyn_cast<CallInst>(&*(Iter));
+
+  // Check that the kernel launch function is called
+  Function *KernelLaunchFunc = Call->getCalledFunction();
+  EXPECT_NE(KernelLaunchFunc, nullptr);
+  StringRef FunctionName = KernelLaunchFunc->getName();
+  EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+
+  // Check the trip count kernel argument (use number 5 starting from the end
+  // and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 6);
+  Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5);
+  EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr));
+  Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(TripCountStore));
+  Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg));
+  EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue());
+
+  // Check the fallback call
+  BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+  Iter = FallbackBlock->rbegin();
+  CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+  // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
+  // have a DISubprogram. In this case, the call to OutlinedFunc needs
+  // to have a debug loc, otherwise verifier will complain.
+  FCall->setDebugLoc(DL);
+  EXPECT_NE(FCall, nullptr);
+
+  // Check that the outlined function exists with the expected prefix
+  Function *OutlinedFunc = FCall->getCalledFunction();
+  EXPECT_NE(OutlinedFunc, nullptr);
+  StringRef FunctionName2 = OutlinedFunc->getName();
+  EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.setConfig(
+      OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false,
+                            /*OpenMPOffloadMandatory=*/false,
+                            /*HasRequiresReverseOffload=*/false,
+                            /*HasRequiresUnifiedAddress=*/false,
+                            /*HasRequiresUnifiedSharedMemory=*/false,
+                            /*HasRequiresDynamicAllocators=*/false));
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  Function *OutlinedFn = nullptr;
+  SmallVector<Value *> CapturedArgs;
+
+  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+                                 OpenMPIRBuilder::InsertPointTy,
+                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
+
+  OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB =
+      [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+    return CombinedInfos;
+  };
+
+  auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
+                       OpenMPIRBuilder::InsertPointTy CodeGenIP)
+      -> OpenMPIRBuilder::InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    OutlinedFn = CodeGenIP.getBlock()->getParent();
+    return Builder.saveIP();
+  };
+
+  IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+                                   F->getEntryBlock().getFirstInsertionPt());
+  TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+                                  /*Line=*/3, /*Count=*/0);
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_SPMD,
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+
+  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
+                       OMPBuilder.createTarget(
+                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
+                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  Builder.restoreIP(AfterIP);
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  // Check outlined function
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  EXPECT_NE(OutlinedFn, nullptr);
+  EXPECT_NE(F, OutlinedFn);
+
+  EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+  // Account for the "implicit" first argument.
+  EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+  EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+
+  // Check global exec_mode.
+  GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+  EXPECT_NE(Used, nullptr);
+  Constant *UsedInit = Used->getInitializer();
+  EXPECT_NE(UsedInit, nullptr);
+  EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+  auto *UsedInitData = cast<ConstantArray>(UsedInit);
+  EXPECT_EQ(1U, UsedInitData->getNumOperands());
+  Constant *ExecMode = UsedInitData->getOperand(0);
+  EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+  Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+  EXPECT_NE(ExecModeValue, nullptr);
+  EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+  EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD,
+            cast<ConstantInt>(ExecModeValue)->getZExtValue());
 }
 
 TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
@@ -6502,15 +6744,16 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
                                    F->getEntryBlock().getFirstInsertionPt());
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
-      /*MinThreads=*/0};
+      /*ExecFlags=*/omp::OMPTgtExecModeFlags::OMP_TGT_EXEC_MODE_GENERIC,
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
-  ASSERT_EXPECTED_INIT(
-      OpenMPIRBuilder::InsertPointTy, AfterIP,
-      OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, DefaultAttrs, CapturedArgs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
+  ASSERT_EXPECTED_INIT(OpenMPIRBuilder::InsertPointTy, AfterIP,
+                       OMPBuilder.createTarget(
+                           Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
+                           EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs,
+                           GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 25b0ffe2ced6de..5c36187540690e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4115,10 +4115,12 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
-  // TODO: Populate default attributes based on the construct and clauses.
+  // TODO: Populate default and runtime attributes based on the construct and
+  // clauses.
+  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
   llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
-      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
-      /*MinThreads=*/0};
+      /*ExecFlags=*/llvm::omp::OMP_TGT_EXEC_MODE_GENERIC, /*MaxTeams=*/{-1},
+      /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
 
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
@@ -4143,8 +4145,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
-          targetOp.getNowait());
+          defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB, bodyCB,
+          argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();



More information about the llvm-commits mailing list