[llvm-branch-commits] [llvm] [mlir] [OMPIRBuilder] Support runtime number of teams and threads, and SPMD mode (PR #116051)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Nov 13 05:41:42 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch introduces a `TargetKernelRuntimeAttrs` structure to hold host-evaluated `num_teams`, `thread_limit`, `num_threads` and trip count values passed to the runtime kernel offloading call.

Additionally, `createTarget` is extended to take an `IsSPMD` flag, used to influence target device code generation.

---

Patch is 31.58 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116051.diff


4 Files Affected:

- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+25-1) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+118-19) 
- (modified) llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp (+271-10) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+6-4) 


``````````diff
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index da450ef5adbc14..a85f41e586c514 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2237,6 +2237,26 @@ class OpenMPIRBuilder {
     int32_t MinThreads = 1;
   };
 
+  /// Container to pass LLVM IR runtime values or constants related to the
+  /// number of teams and threads with which the kernel must be launched, as
+  /// well as the trip count of the SPMD loop, if it is an SPMD kernel. These
+  /// must be defined in the host prior to the call to the kernel launch OpenMP
+  /// RTL function.
+  struct TargetKernelRuntimeAttrs {
+    SmallVector<Value *, 3> MaxTeams = {nullptr};
+    Value *MinTeams = nullptr;
+    SmallVector<Value *, 3> TargetThreadLimit = {nullptr};
+    SmallVector<Value *, 3> TeamsThreadLimit = {nullptr};
+
+    /// 'parallel' construct 'num_threads' clause value, if present and it is a
+    /// target SPMD kernel.
+    Value *MaxThreads = nullptr;
+
+    /// Total number of iterations of the target SPMD kernel or null if it is a
+    /// generic kernel.
+    Value *LoopTripCount = nullptr;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2905,11 +2925,14 @@ class OpenMPIRBuilder {
   ///
   /// \param Loc where the target data construct was encountered.
   /// \param IsOffloadEntry whether it is an offload entry.
+  /// \param IsSPMD whether it is a target SPMD kernel.
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
   /// \param DefaultAttrs Structure containing the default numbers of threads
   ///        and teams to launch the kernel with.
+  /// \param RuntimeAttrs Structure containing the runtime numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2919,11 +2942,12 @@ class OpenMPIRBuilder {
   // dependency information as passed in the depend clause
   // \param HasNowait Whether the target construct has a `nowait` clause or not.
   InsertPointOrErrorTy createTarget(
-      const LocationDescription &Loc, bool IsOffloadEntry,
+      const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
       TargetRegionEntryInfo &EntryInfo,
       const TargetKernelDefaultAttrs &DefaultAttrs,
+      const TargetKernelRuntimeAttrs &RuntimeAttrs,
       SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
       TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 302d363965c940..f847f60386df85 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6727,8 +6727,43 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
   return getOrCreateRuntimeFunction(M, omp::OMPRTL___kmpc_dispatch_deinit);
 }
 
+static void emitUsed(StringRef Name, std::vector<llvm::WeakTrackingVH> &List,
+                     Module &M) {
+  if (List.empty())
+    return;
+
+  Type *PtrTy = PointerType::get(M.getContext(), /*AddressSpace=*/0);
+
+  // Convert List to what ConstantArray needs.
+  SmallVector<Constant *, 8> UsedArray;
+  UsedArray.reserve(List.size());
+  for (auto Item : List)
+    UsedArray.push_back(ConstantExpr::getPointerBitCastOrAddrSpaceCast(
+        cast<Constant>(&*Item), PtrTy));
+
+  ArrayType *ArrTy = ArrayType::get(PtrTy, UsedArray.size());
+  auto *GV =
+      new GlobalVariable(M, ArrTy, false, llvm::GlobalValue::AppendingLinkage,
+                         llvm::ConstantArray::get(ArrTy, UsedArray), Name);
+
+  GV->setSection("llvm.metadata");
+}
+
+static void
+emitExecutionMode(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+                  StringRef FunctionName, OMPTgtExecModeFlags Mode,
+                  std::vector<llvm::WeakTrackingVH> &LLVMCompilerUsed) {
+  auto *Int8Ty = Type::getInt8Ty(Builder.getContext());
+  auto *GVMode = new llvm::GlobalVariable(
+      OMPBuilder.M, Int8Ty, /*isConstant=*/true,
+      llvm::GlobalValue::WeakAnyLinkage, llvm::ConstantInt::get(Int8Ty, Mode),
+      Twine(FunctionName, "_exec_mode"));
+  GVMode->setVisibility(llvm::GlobalVariable::ProtectedVisibility);
+  LLVMCompilerUsed.emplace_back(GVMode);
+}
+
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsSPMD,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
@@ -6758,6 +6793,27 @@ static Expected<Function *> createOutlinedFunction(
   auto Func =
       Function::Create(FuncType, GlobalValue::InternalLinkage, FuncName, M);
 
+  // Forward target-cpu and target-features function attributes from the
+  // original function to the new outlined function.
+  Function *ParentFn = Builder.GetInsertBlock()->getParent();
+
+  auto TargetCpuAttr = ParentFn->getFnAttribute("target-cpu");
+  if (TargetCpuAttr.isStringAttribute())
+    Func->addFnAttr(TargetCpuAttr);
+
+  auto TargetFeaturesAttr = ParentFn->getFnAttribute("target-features");
+  if (TargetFeaturesAttr.isStringAttribute())
+    Func->addFnAttr(TargetFeaturesAttr);
+
+  if (OMPBuilder.Config.isTargetDevice()) {
+    std::vector<llvm::WeakTrackingVH> LLVMCompilerUsed;
+    emitExecutionMode(OMPBuilder, Builder, FuncName,
+                      IsSPMD ? OMP_TGT_EXEC_MODE_SPMD
+                             : OMP_TGT_EXEC_MODE_GENERIC,
+                      LLVMCompilerUsed);
+    emitUsed("llvm.compiler.used", LLVMCompilerUsed, OMPBuilder.M);
+  }
+
   // Save insert point.
   IRBuilder<>::InsertPointGuard IPG(Builder);
   // If there's a DISubprogram associated with current function, then
@@ -6798,7 +6854,7 @@ static Expected<Function *> createOutlinedFunction(
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
     Builder.restoreIP(
-        OMPBuilder.createTargetInit(Builder, /*IsSPMD=*/false, DefaultAttrs));
+        OMPBuilder.createTargetInit(Builder, IsSPMD, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6995,7 +7051,7 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo,
+    bool IsSPMD, TargetRegionEntryInfo &EntryInfo,
     const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
     Function *&OutlinedFn, Constant *&OutlinedFnID,
     SmallVectorImpl<Value *> &Inputs,
@@ -7004,7 +7060,7 @@ static Error emitTargetOutlinedFunction(
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
       [&](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+        return createOutlinedFunction(OMPBuilder, Builder, IsSPMD, DefaultAttrs,
                                       EntryFnName, Inputs, CBFunc,
                                       ArgAccessorFuncCB);
       };
@@ -7304,6 +7360,7 @@ static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                OpenMPIRBuilder::InsertPointTy AllocaIP,
                const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               const OpenMPIRBuilder::TargetKernelRuntimeAttrs &RuntimeAttrs,
                Function *OutlinedFn, Constant *OutlinedFnID,
                SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
@@ -7385,11 +7442,43 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
                                          /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  for (auto [DefaultVal, RuntimeVal] :
+       zip_equal(DefaultAttrs.MaxTeams, RuntimeAttrs.MaxTeams))
+    NumTeamsC.push_back(RuntimeVal ? RuntimeVal : Builder.getInt32(DefaultVal));
+
+  // Calculate number of threads: 0 if no clauses specified, otherwise it is the
+  // minimum between optional THREAD_LIMIT and NUM_THREADS clauses.
+  auto InitMaxThreadsClause = [&Builder](Value *Clause) {
+    if (Clause)
+      Clause = Builder.CreateIntCast(Clause, Builder.getInt32Ty(),
+                                     /*isSigned=*/false);
+    return Clause;
+  };
+  auto CombineMaxThreadsClauses = [&Builder](Value *Clause, Value *&Result) {
+    if (Clause)
+      Result = Result
+                   ? Builder.CreateSelect(Builder.CreateICmpULT(Result, Clause),
+                                          Result, Clause)
+                   : Clause;
+  };
+
+  // If a multi-dimensional THREAD_LIMIT is set, it is the OMPX_BARE case, so
+  // the NUM_THREADS clause is overriden by THREAD_LIMIT.
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : DefaultAttrs.MaxTeams)
-    NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : DefaultAttrs.MaxThreads)
-    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  Value *MaxThreadsClause = RuntimeAttrs.TeamsThreadLimit.size() == 1
+                                ? InitMaxThreadsClause(RuntimeAttrs.MaxThreads)
+                                : nullptr;
+
+  for (auto [TeamsVal, TargetVal] : llvm::zip_equal(
+           RuntimeAttrs.TeamsThreadLimit, RuntimeAttrs.TargetThreadLimit)) {
+    Value *TeamsThreadLimitClause = InitMaxThreadsClause(TeamsVal);
+    Value *NumThreads = InitMaxThreadsClause(TargetVal);
+
+    CombineMaxThreadsClauses(TeamsThreadLimitClause, NumThreads);
+    CombineMaxThreadsClauses(MaxThreadsClause, NumThreads);
+
+    NumThreadsC.push_back(NumThreads ? NumThreads : Builder.getInt32(0));
+  }
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
@@ -7398,14 +7487,19 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
                                              llvm::omp::IdentFlag(0), 0);
-  // TODO: Use correct NumIterations
-  Value *NumIterations = Builder.getInt64(0);
+
+  Value *TripCount = RuntimeAttrs.LoopTripCount
+                         ? Builder.CreateIntCast(RuntimeAttrs.LoopTripCount,
+                                                 Builder.getInt64Ty(),
+                                                 /*isSigned=*/false)
+                         : Builder.getInt64(0);
+
   // TODO: Use correct DynCGGroupMem
   Value *DynCGGroupMem = Builder.getInt32(0);
 
-  KArgs = OpenMPIRBuilder::TargetKernelArgs(
-      NumTargetItems, RTArgs, NumIterations, NumTeamsC, NumThreadsC,
-      DynCGGroupMem, HasNoWait);
+  KArgs = OpenMPIRBuilder::TargetKernelArgs(NumTargetItems, RTArgs, TripCount,
+                                            NumTeamsC, NumThreadsC,
+                                            DynCGGroupMem, HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7427,13 +7521,17 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 }
 
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
-    const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
-    InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
+    const LocationDescription &Loc, bool IsOffloadEntry, bool IsSPMD,
+    InsertPointTy AllocaIP, InsertPointTy CodeGenIP,
+    TargetRegionEntryInfo &EntryInfo,
     const TargetKernelDefaultAttrs &DefaultAttrs,
+    const TargetKernelRuntimeAttrs &RuntimeAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
     SmallVector<DependData> Dependencies, bool HasNowait) {
+  assert((!RuntimeAttrs.LoopTripCount || IsSPMD) &&
+         "trip count not expected if IsSPMD=false");
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7446,16 +7544,17 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
-          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, IsSPMD, EntryInfo, DefaultAttrs,
+          OutlinedFn, OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
-                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, RuntimeAttrs,
+                   OutlinedFn, OutlinedFnID, Args, GenMapInfoCB, Dependencies,
+                   HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index b0688d6215e42d..63be7e775b83c9 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6122,8 +6122,10 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   OpenMPIRBuilderConfig Config(false, false, false, false, false, false, false);
   OMPBuilder.setConfig(Config);
   F->setName("func");
+  F->addFnAttr("target-cpu", "x86-64");
+  F->addFnAttr("target-features", "+mmx,+sse");
   IRBuilder<> Builder(BB);
-  auto Int32Ty = Builder.getInt32Ty();
+  auto *Int32Ty = Builder.getInt32Ty();
 
   AllocaInst *APtr = Builder.CreateAlloca(Int32Ty, nullptr, "a_ptr");
   AllocaInst *BPtr = Builder.CreateAlloca(Int32Ty, nullptr, "b_ptr");
@@ -6183,11 +6185,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
-      /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
-  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP =
-      OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
-                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      /*MaxTeams=*/{10}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
+  RuntimeAttrs.TargetThreadLimit[0] = Builder.getInt32(20);
+  RuntimeAttrs.TeamsThreadLimit[0] = Builder.getInt32(30);
+  RuntimeAttrs.MaxThreads = Builder.getInt32(40);
+  OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
+      OmpLoc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, Builder.saveIP(),
+      Builder.saveIP(), EntryInfo, DefaultAttrs, RuntimeAttrs, Inputs,
+      GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
   OMPBuilder.finalize();
@@ -6207,6 +6213,43 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName = KernelLaunchFunc->getName();
   EXPECT_TRUE(FunctionName.starts_with("__tgt_target_kernel"));
 
+  // Check num_teams and num_threads in call arguments
+  EXPECT_TRUE(Call->arg_size() >= 4);
+  Value *NumTeamsArg = Call->getArgOperand(2);
+  EXPECT_TRUE(isa<ConstantInt>(NumTeamsArg));
+  EXPECT_EQ(10U, cast<ConstantInt>(NumTeamsArg)->getZExtValue());
+  Value *NumThreadsArg = Call->getArgOperand(3);
+  EXPECT_TRUE(isa<ConstantInt>(NumThreadsArg));
+  EXPECT_EQ(20U, cast<ConstantInt>(NumThreadsArg)->getZExtValue());
+
+  // Check num_teams and num_threads kernel arguments (use number 5 starting
+  // from the end and counting the call to __tgt_target_kernel as the first use)
+  Value *KernelArgs = Call->getArgOperand(Call->arg_size() - 1);
+  EXPECT_TRUE(KernelArgs->getNumUses() >= 4);
+  Value *NumTeamsGetElemPtr = *std::next(KernelArgs->user_begin(), 3);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumTeamsGetElemPtr));
+  Value *NumTeamsStore = NumTeamsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumTeamsStore));
+  Value *NumTeamsStoreArg = cast<StoreInst>(NumTeamsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumTeamsStoreArg));
+  auto *NumTeamsStoreValue = cast<ConstantDataSequential>(NumTeamsStoreArg);
+  EXPECT_EQ(3U, NumTeamsStoreValue->getNumElements());
+  EXPECT_EQ(10U, NumTeamsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumTeamsStoreValue->getElementAsInteger(2));
+  Value *NumThreadsGetElemPtr = *std::next(KernelArgs->user_begin(), 2);
+  EXPECT_TRUE(isa<GetElementPtrInst>(NumThreadsGetElemPtr));
+  Value *NumThreadsStore = NumThreadsGetElemPtr->getUniqueUndroppableUser();
+  EXPECT_TRUE(isa<StoreInst>(NumThreadsStore));
+  Value *NumThreadsStoreArg =
+      cast<StoreInst>(NumThreadsStore)->getValueOperand();
+  EXPECT_TRUE(isa<ConstantDataSequential>(NumThreadsStoreArg));
+  auto *NumThreadsStoreValue = cast<ConstantDataSequential>(NumThreadsStoreArg);
+  EXPECT_EQ(3U, NumThreadsStoreValue->getNumElements());
+  EXPECT_EQ(20U, NumThreadsStoreValue->getElementAsInteger(0));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(1));
+  EXPECT_EQ(0U, NumThreadsStoreValue->getElementAsInteger(2));
+
   // Check the fallback call
   BasicBlock *FallbackBlock = Branch->getSuccessor(0);
   Iter = FallbackBlock->rbegin();
@@ -6228,6 +6271,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
   StringRef FunctionName2 = OutlinedFunc->getName();
   EXPECT_TRUE(FunctionName2.starts_with("__omp_offloading"));
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-cpu"),
+            F->getFnAttribute("target-cpu"));
+  EXPECT_EQ(OutlinedFunc->getFnAttribute("target-features"),
+            F->getFnAttribute("target-features"));
+
   EXPECT_FALSE(verifyModule(*M, &errs()));
 }
 
@@ -6238,6 +6288,8 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   OMPBuilder.initialize();
 
   F->setName("func");
+  F->addFnAttr("target-cpu", "gfx90a");
+  F->addFnAttr("target-features", "+gfx9-insts,+wavefrontsize64");
   IRBuilder<> Builder(BB);
   OpenMPIRBuilder::LocationDescription Loc({Builder.saveIP(), DL});
 
@@ -6297,9 +6349,11 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
 
   OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
       /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0}, /*MinThreads=*/0};
+  OpenMPIRBuilder::TargetKernelRuntimeAttrs RuntimeAttrs;
   OpenMPIRBuilder::InsertPointOrErrorTy AfterIP = OMPBuilder.createTarget(
-      Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP, EntryInfo, DefaultAttrs,
-      CapturedArgs, GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB);
+      Loc, /*IsOffloadEntry=*/true, /*IsSPMD=*/false, EntryIP, EntryIP,
+      EntryInfo, DefaultAttrs, RuntimeAttrs, CapturedArgs, GenMapInfoCB,
+      BodyGenCB, SimpleArgAccessorCB);
   assert(AfterIP && "unexpected error");
   Builder.restoreIP(*AfterIP);
 
@@ -6312,6 +6366,13 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
   Function *OutlinedFn = TargetStore->getFunction();
   EXPECT_NE(F, OutlinedFn);
 
+  // Check that target-cpu and target-features were propagated to the outlined
+  // function
+  EXPECT_EQ(OutlinedFn->getFnAttribute("target-cpu"),
+       ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/116051


More information about the llvm-branch-commits mailing list