[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)

Sergio Afonso via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 27 04:26:12 PST 2024


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116051

>From f120456cd3200ff82cca63570272d57ba909fe87 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 11:33:01 +0000
Subject: [PATCH 1/2] [OMPIRBuilder] Support runtime number of teams and
 threads, and SPMD mode

This patch introduces a `TargetKernelRuntimeAttrs` structure to hold
host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values
passed to the runtime kernel offloading call.

Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to
influence target device code generation.
---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  26 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 125 +++++++--
 .../Frontend/OpenMPIRBuilderTest.cpp          | 256 +++++++++++++++++-
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  10 +-
 4 files changed, 383 insertions(+), 34 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
     int32_t MinThreads = 1;
   };
 
+  /// Container to pass LLVM IR runtime values or constants related to the
+  /// number of teams and threads with which the kernel must be launched, as
+  /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+  /// must be defined in the host prior to the call to the kernel launch OpenMP
+  /// RTL function.
+  struct TargetKernelRuntimeAttrs {
+    SmallVector<Value *, 3> MaxTeams = {nullptr};
+    Value *MinTeams = nullptr;
+    SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+    SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+    /// 'parallel' construct 'num_threads' clause value, if present and it is a
+    /// target SPMD kernel.
+    Value *MaxThreads = nullptr;
+
+    /// Total number of iterations of the target SPMD kernel or null if it is a
+    /// generic kernel.
+    Value *LoopTripCount = nullptr;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
+  /// \param IsSPMD whether it is a target SPMD kernel.
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
   // dependency information as passed in the depend clause
   // \param HasNowait Whether the target construct has a `nowait` clause or not.
   InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
+      const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..09f794ccf734b3 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
 }
 
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+                     Module &M) {
+  if (List.empty())
+    return;
+
+  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (auto Item : List)
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Item), PtrTy));
+
+  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+  auto *GV =
+      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                  StringRef FunctionName, OMPTgtExecModeFlags Mode,
+                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+  auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+  auto *GVMode = new llvm::GlobalVariable(
+      OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+      llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+      Twine(FunctionName, "_exec_mode"));
+  GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+  LLVMCompilerUsed.emplace_back(GVMode);
+}
+
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,15 @@ static Expected<Function *> createOutlinedFunction(
   auto Func =
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
+  if (OMPBuilder.Config.isTargetDevice()) {
+    std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+    emitExecutionMode(OMPBuilder, Builder, FuncName,
+                      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+                             : OMP_TGT_EXEC_MODE_GENERIC,
+                      LLVMCompilerUsed);
+    emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+  }
+
   // Save insert point.
   IRBuilder<>::InsertPointGuard IPG(Builder);
   // If there's a DISubprogram associated with current function, then
@@ -6798,7 +6842,7 @@ static Expected<Function *> createOutlinedFunction(
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(
-        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+        OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6995,7 +7039,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo,
+    bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     Function *&OutlinedFn, Constant *&OutlinedFnID,
     SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7048,7 @@ static Error emitTargetOutlinedFunction(
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
       [&](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+        return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
                                       EntryFnName, Inputs, CBFunc,
                                       ArgAccessorFuncCB);
       };
@@ -7304,6 +7348,7 @@ static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
                Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7430,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                          /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  for (auto [DefaultVal, RuntimeVal] :
+       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+    if (Clause)
+      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                     /*isSigned=*/false);
+    return Clause;
+  };
+  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+    if (Clause)
+      Result = Result
+                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+                                          Result, Clause)
+                   : Clause;
+  };
+
+  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+  // the NUM_THREADS clause is overridden by THREAD_LIMIT.
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : DefaultAttrs.MaxTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : DefaultAttrs.MaxThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                : nullptr;
+
+  for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+           RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+  }
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
@@ -7398,14 +7475,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                              llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
+
+  Value *TripCount = RuntimeAttrs.LoopTripCount
+                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                 Builder.getInt64Ty(),
+                                                 /*isSigned=*/false)
+                         : Builder.getInt64(0);
+
   // TODO: Use correct DynCGGroupMem
   Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
+  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                            NumTeamsC, NumThreadsC,
+                                            DynCGGroupMem, HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7427,13 +7509,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
-    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+    const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+    InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+    TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
+  assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+         "trip count not expected if IsSPMD=false");
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7446,16 +7532,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
-          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+          OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
-                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+                   OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..a8c786b5886afe 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6123,7 +6123,7 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   OMPBuilder.setConfig(Config);
   F->setName("func");
   IRBuilder<> Builder(BB);
-  auto Int32Ty = Builder.getInt32Ty();
+  auto *Int32Ty = Builder.getInt32Ty();
 
   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6183,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+  RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+  RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6207,6 +6211,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName = KernelLaunchFunc->getName();
   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
 
+  // Check num_teams and num_threads in call arguments
+  EXPECT_TRUE(Call->arg_size() >= 4);
+  Value *NumTeamsArg = Call->getArgOperand(2);
+  EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+  EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+  Value *NumThreadsArg = Call->getArgOperand(3);
+  EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+  EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+  // Check num_teams and num_threads kernel arguments (users at index 3 and 2
+  // from the end, counting the call to __tgt_target_kernel as index 0)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+  Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+  Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+  Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+  auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+  EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+  EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+  Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+  Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+  Value *NumThreadsStoreArg =
+      cast<StoreInst>(NumThreadsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+  auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+  EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+  EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
   // Check the fallback call
   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
   Iter = FallbackBlock->rbegin();
@@ -6297,9 +6338,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6378,6 +6421,197 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   auto *ExitBlock = EntryBlockBranch->getSuccessor(1);
   EXPECT_EQ(ExitBlock->getName(), "worker.exit");
   EXPECT_TRUE(isa<ReturnInst>(ExitBlock->getFirstNonPHI()));
+
+  // Check global exec_mode.
+  GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+  EXPECT_NE(Used, nullptr);
+  Constant *UsedInit = Used->getInitializer();
+  EXPECT_NE(UsedInit, nullptr);
+  EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+  auto *UsedInitData = cast<ConstantArray>(UsedInit);
+  EXPECT_EQ(1U, UsedInitData->getNumOperands());
+  Constant *ExecMode = UsedInitData->getOperand(0);
+  EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+  Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+  EXPECT_NE(ExecModeValue, nullptr);
+  EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+  EXPECT_EQ(OMP_TGT_EXEC_MODE_GENERIC,
+            cast<ConstantInt>(ExecModeValue)->getZExtValue());
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
+  using InsertPointTy = OpenMPIRBuilder::InsertPointTy;
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.initialize();
+  OpenMPIRBuilderConfig Config(/*IsTargetDevice=*/false, /*IsGPU=*/false,
+                               /*OpenMPOffloadMandatory=*/false,
+                               /*HasRequiresReverseOffload=*/false,
+                               /*HasRequiresUnifiedAddress=*/false,
+                               /*HasRequiresUnifiedSharedMemory=*/false,
+                               /*HasRequiresDynamicAllocators=*/false);
+  OMPBuilder.setConfig(Config);
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+
+  auto BodyGenCB = [&](InsertPointTy,
+                       InsertPointTy CodeGenIP) -> InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
+
+  auto SimpleArgAccessorCB =
+      [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+          llvm::OpenMPIRBuilder::InsertPointTy,
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
+
+  llvm::SmallVector<llvm::Value *> Inputs;
+  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+  TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
+  OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.LoopTripCount = Builder.getInt64(1000);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+  assert(AfterIP && "unexpected error");
+  Builder.restoreIP(*AfterIP);
+  OMPBuilder.finalize();
+  Builder.CreateRetVoid();
+
+  // Check the kernel launch sequence
+  auto Iter = F->getEntryBlock().rbegin();
+  EXPECT_TRUE(isa<BranchInst>(&*(Iter)));
+  BranchInst *Branch = dyn_cast<BranchInst>(&*(Iter));
+  EXPECT_TRUE(isa<CmpInst>(&*(++Iter)));
+  EXPECT_TRUE(isa<CallInst>(&*(++Iter)));
+  CallInst *Call = dyn_cast<CallInst>(&*(Iter));
+
+  // Check that the kernel launch function is called
+  Function *KernelLaunchFunc = Call->getCalledFunction();
+  EXPECT_NE(KernelLaunchFunc, nullptr);
+  StringRef FunctionName = KernelLaunchFunc->getName();
+  EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
+
+  // Check the trip count kernel argument (use number 5 starting from the end
+  // and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 6);
+  Value *TripCountGetElemPtr = *std::next(KernelArgs->user_begin(), 5);
+  EXPECT_TRUE(isa<GetElementPtrInst>(TripCountGetElemPtr));
+  Value *TripCountStore = TripCountGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(TripCountStore));
+  Value *TripCountStoreArg = cast<StoreInst>(TripCountStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantInt>(TripCountStoreArg));
+  EXPECT_EQ(1000U, cast<ConstantInt>(TripCountStoreArg)->getZExtValue());
+
+  // Check the fallback call
+  BasicBlock *FallbackBlock = Branch->getSuccessor(0);
+  Iter = FallbackBlock->rbegin();
+  CallInst *FCall = dyn_cast<CallInst>(&*(++Iter));
+  // 'F' has a dummy DISubprogram which causes OutlinedFunc to also
+  // have a DISubprogram. In this case, the call to OutlinedFunc needs
+  // to have a debug loc, otherwise verifier will complain.
+  FCall->setDebugLoc(DL);
+  EXPECT_NE(FCall, nullptr);
+
+  // Check that the outlined function exists with the expected prefix
+  Function *OutlinedFunc = FCall->getCalledFunction();
+  EXPECT_NE(OutlinedFunc, nullptr);
+  StringRef FunctionName2 = OutlinedFunc->getName();
+  EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
+
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+}
+
+TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
+  OpenMPIRBuilder OMPBuilder(*M);
+  OMPBuilder.setConfig(
+      OpenMPIRBuilderConfig(/*IsTargetDevice=*/true, /*IsGPU=*/false,
+                            /*OpenMPOffloadMandatory=*/false,
+                            /*HasRequiresReverseOffload=*/false,
+                            /*HasRequiresUnifiedAddress=*/false,
+                            /*HasRequiresUnifiedSharedMemory=*/false,
+                            /*HasRequiresDynamicAllocators=*/false));
+  OMPBuilder.initialize();
+  F->setName("func");
+  IRBuilder<> Builder(BB);
+  OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
+
+  Function *OutlinedFn = nullptr;
+  llvm::SmallVector<llvm::Value *> CapturedArgs;
+
+  auto SimpleArgAccessorCB =
+      [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
+          llvm::OpenMPIRBuilder::InsertPointTy,
+          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+        Builder.restoreIP(CodeGenIP);
+        return Builder.saveIP();
+      };
+
+  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
+      -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+
+  auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
+                       OpenMPIRBuilder::InsertPointTy CodeGenIP)
+      -> OpenMPIRBuilder::InsertPointTy {
+    Builder.restoreIP(CodeGenIP);
+    OutlinedFn = CodeGenIP.getBlock()->getParent();
+    return Builder.saveIP();
+  };
+
+  IRBuilder<>::InsertPoint EntryIP(&F->getEntryBlock(),
+                                   F->getEntryBlock().getFirstInsertionPt());
+  TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
+                                  /*Line=*/3, /*Count=*/0);
+
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/true, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
+  assert(AfterIP && "unexpected error");
+  Builder.restoreIP(*AfterIP);
+
+  Builder.CreateRetVoid();
+  OMPBuilder.finalize();
+
+  // Check outlined function
+  EXPECT_FALSE(verifyModule(*M, &errs()));
+  EXPECT_NE(OutlinedFn, nullptr);
+  EXPECT_NE(F, OutlinedFn);
+
+  EXPECT_TRUE(OutlinedFn->hasWeakODRLinkage());
+  // Account for the "implicit" first argument.
+  EXPECT_EQ(OutlinedFn->getName(), "__omp_offloading_1_2_parent_l3");
+  EXPECT_EQ(OutlinedFn->arg_size(), 1U);
+
+  // Check global exec_mode.
+  GlobalVariable *Used = M->getGlobalVariable("llvm.compiler.used");
+  EXPECT_NE(Used, nullptr);
+  Constant *UsedInit = Used->getInitializer();
+  EXPECT_NE(UsedInit, nullptr);
+  EXPECT_TRUE(isa<ConstantArray>(UsedInit));
+  auto *UsedInitData = cast<ConstantArray>(UsedInit);
+  EXPECT_EQ(1U, UsedInitData->getNumOperands());
+  Constant *ExecMode = UsedInitData->getOperand(0);
+  EXPECT_TRUE(isa<GlobalVariable>(ExecMode));
+  Constant *ExecModeValue = cast<GlobalVariable>(ExecMode)->getInitializer();
+  EXPECT_NE(ExecModeValue, nullptr);
+  EXPECT_TRUE(isa<ConstantInt>(ExecModeValue));
+  EXPECT_EQ(OMP_TGT_EXEC_MODE_SPMD,
+            cast<ConstantInt>(ExecModeValue)->getZExtValue());
 }
 
 TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
@@ -6448,9 +6682,11 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d3c3839accb7e7..9bdf3e11496f3a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3936,9 +3936,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
-  // TODO: Populate default attributes based on the construct and clauses.
+  // TODO: Populate default and runtime attributes based on the construct and
+  // clauses.
   llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  llvm::OpenMPIRBuilder::TargetKernelRuntimeAttrs runtimeAttrs;
 
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
@@ -3957,9 +3959,9 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
 
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
-          ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
-          targetOp.getNowait());
+          ompLoc, isOffloadEntry, /*IsSPMD=*/false, allocaIP, builder.saveIP(),
+          entryInfo, defaultAttrs, runtimeAttrs, kernelInput, genMapInfoCB,
+          bodyCB, argAccessorCB, dds, targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();

>From e2b3ac479a5333988a2e8dcbaff26d7a87ea623c Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Wed, 27 Nov 2024 12:08:46 +0000
Subject: [PATCH 2/2] Address review comments

---
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       |  10 +-
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 104 +++++++-----------
 .../Frontend/OpenMPIRBuilderTest.cpp          |  46 ++++----
 3 files changed, 68 insertions(+), 92 deletions(-)

diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index a85f41e586c514..f0ef58ed3b19ba 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -1387,9 +1387,6 @@ class OpenMPIRBuilder {
 
   /// Supporting functions for Reductions CodeGen.
 private:
-  /// Emit the llvm.used metadata.
-  void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List);
-
   /// Get the id of the current thread on the GPU.
   Value *getGPUThreadID();
 
@@ -2011,6 +2008,13 @@ class OpenMPIRBuilder {
   /// Value.
   GlobalValue *createGlobalFlag(unsigned Value, StringRef Name);
 
+  /// Emit the appending global \p Name (e.g. llvm.compiler.used) from \p List.
+  void emitUsed(StringRef Name, ArrayRef<llvm::WeakTrackingVH> List);
+
+  /// Create the constant `<KernelName>_exec_mode` i8 global holding \p Mode.
+  GlobalVariable *emitKernelExecutionMode(StringRef KernelName,
+                                          omp::OMPTgtExecModeFlags Mode);
+
   /// Generate control flow and cleanup for cancellation.
   ///
   /// \param CancelFlag Flag indicating if the cancellation is performed.
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 09f794ccf734b3..73f221c07af746 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -831,6 +831,38 @@ GlobalValue *OpenMPIRBuilder::createGlobalFlag(unsigned Value, StringRef Name) {
   return GV;
 }
 
+void OpenMPIRBuilder::emitUsed(StringRef Name, ArrayRef<WeakTrackingVH> List) {
+  if (List.empty())
+    return;
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.resize(List.size());
+  for (unsigned I = 0, E = List.size(); I != E; ++I)
+    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*List[I]), Builder.getPtrTy());
+
+  if (UsedArray.empty())
+    return;
+  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
+
+  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
+                                ConstantArray::get(ATy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+GlobalVariable *
+OpenMPIRBuilder::emitKernelExecutionMode(StringRef KernelName,
+                                         OMPTgtExecModeFlags Mode) {
+  auto *Int8Ty = Builder.getInt8Ty();
+  auto *GVMode = new GlobalVariable(
+      M, Int8Ty, /*isConstant=*/true, GlobalValue::WeakAnyLinkage,
+      ConstantInt::get(Int8Ty, Mode), Twine(KernelName, "_exec_mode"));
+  GVMode->setVisibility(GlobalVariable::ProtectedVisibility);
+  return GVMode;
+}
+
 Constant *OpenMPIRBuilder::getOrCreateIdent(Constant *SrcLocStr,
                                             uint32_t SrcLocStrSize,
                                             IdentFlag LocFlags,
@@ -2242,28 +2274,6 @@ static OpenMPIRBuilder::InsertPointTy getInsertPointAfterInstr(Instruction *I) {
   return OpenMPIRBuilder::InsertPointTy(I->getParent(), IT);
 }
 
-void OpenMPIRBuilder::emitUsed(StringRef Name,
-                               std::vector<WeakTrackingVH> &List) {
-  if (List.empty())
-    return;
-
-  // Convert List to what ConstantArray needs.
-  SmallVector<Constant *, 8> UsedArray;
-  UsedArray.resize(List.size());
-  for (unsigned I = 0, E = List.size(); I != E; ++I)
-    UsedArray[I] = ConstantExpr::getPointerBitCastOrAddrSpaceCast(
-        cast<Constant>(&*List[I]), Builder.getPtrTy());
-
-  if (UsedArray.empty())
-    return;
-  ArrayType *ATy = ArrayType::get(Builder.getPtrTy(), UsedArray.size());
-
-  auto *GV = new GlobalVariable(M, ATy, false, GlobalValue::AppendingLinkage,
-                                ConstantArray::get(ATy, UsedArray), Name);
-
-  GV->setSection("llvm.metadata");
-}
-
 Value *OpenMPIRBuilder::getGPUThreadID() {
   return Builder.CreateCall(
       getOrCreateRuntimeFunction(M,
@@ -6727,41 +6737,6 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
 }
 
-static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
-                     Module &M) {
-  if (List.empty())
-    return;
-
-  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
-
-  // Convert List to what ConstantArray needs.
-  SmallVector<Constant *, 8> UsedArray;
-  UsedArray.reserve(List.size());
-  for (auto Item : List)
-    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
-        cast<Constant>(&*Item), PtrTy));
-
-  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
-  auto *GV =
-      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
-                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);
-
-  GV->setSection("llvm.metadata");
-}
-
-static void
-emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-                  StringRef FunctionName, OMPTgtExecModeFlags Mode,
-                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
-  auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
-  auto *GVMode = new llvm::GlobalVariable(
-      OMPBuilder.M, Int8Ty, /*isConstant=*/true,
-      llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
-      Twine(FunctionName, "_exec_mode"));
-  GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
-  LLVMCompilerUsed.emplace_back(GVMode);
-}
-
 static Expected<Function *> createOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
@@ -6794,12 +6769,9 @@ static Expected<Function *> createOutlinedFunction(
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
   if (OMPBuilder.Config.isTargetDevice()) {
-    std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
-    emitExecutionMode(OMPBuilder, Builder, FuncName,
-                      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
-                             : OMP_TGT_EXEC_MODE_GENERIC,
-                      LLVMCompilerUsed);
-    emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+    Value *ExecMode = OMPBuilder.emitKernelExecutionMode(
+        FuncName, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
+    OMPBuilder.emitUsed("llvm.compiler.used", {ExecMode});
   }
 
   // Save insert point.
@@ -7457,8 +7429,8 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                 ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
                                 : nullptr;
 
-  for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
-           RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+  for (auto [TeamsVal, TargetVal] : zip_equal(RuntimeAttrs.TeamsThreadLimit,
+                                              RuntimeAttrs.TargetThreadLimit)) {
     Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
     Value *NumThreads = InitMaxThreadsClause(TargetVal);
 
@@ -7518,8 +7490,6 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
-  assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
-         "trip count not expected if IsSPMD=false");
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index a8c786b5886afe..e4845256633b9c 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6459,18 +6459,19 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionSPMD) {
     return Builder.saveIP();
   };
 
-  auto SimpleArgAccessorCB =
-      [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
-          llvm::OpenMPIRBuilder::InsertPointTy,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
-        Builder.restoreIP(CodeGenIP);
-        return Builder.saveIP();
-      };
+  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+                                 OpenMPIRBuilder::InsertPointTy,
+                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
 
-  llvm::SmallVector<llvm::Value *> Inputs;
-  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
-  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
-      -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+  SmallVector<Value *> Inputs;
+  OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB =
+      [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+    return CombinedInfos;
+  };
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
@@ -6547,19 +6548,20 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDeviceSPMD) {
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
   Function *OutlinedFn = nullptr;
-  llvm::SmallVector<llvm::Value *> CapturedArgs;
+  SmallVector<Value *> CapturedArgs;
 
-  auto SimpleArgAccessorCB =
-      [&](llvm::Argument &, llvm::Value *, llvm::Value *&,
-          llvm::OpenMPIRBuilder::InsertPointTy,
-          llvm::OpenMPIRBuilder::InsertPointTy CodeGenIP) {
-        Builder.restoreIP(CodeGenIP);
-        return Builder.saveIP();
-      };
+  auto SimpleArgAccessorCB = [&](Argument &, Value *, Value *&,
+                                 OpenMPIRBuilder::InsertPointTy,
+                                 OpenMPIRBuilder::InsertPointTy CodeGenIP) {
+    Builder.restoreIP(CodeGenIP);
+    return Builder.saveIP();
+  };
 
-  llvm::OpenMPIRBuilder::MapInfosTy CombinedInfos;
-  auto GenMapInfoCB = [&](llvm::OpenMPIRBuilder::InsertPointTy)
-      -> llvm::OpenMPIRBuilder::MapInfosTy & { return CombinedInfos; };
+  OpenMPIRBuilder::MapInfosTy CombinedInfos;
+  auto GenMapInfoCB =
+      [&](OpenMPIRBuilder::InsertPointTy) -> OpenMPIRBuilder::MapInfosTy & {
+    return CombinedInfos;
+  };
 
   auto BodyGenCB = [&](OpenMPIRBuilder::InsertPointTy,
                        OpenMPIRBuilder::InsertPointTy CodeGenIP)



More information about the llvm-branch-commits mailing list