[clang] [llvm] [mlir] [OMPIRBuilder] Introduce struct to hold default kernel teams/threads (PR #116050)

Sergio Afonso via llvm-commits llvm-commits at lists.llvm.org
Tue Jan 14 02:28:18 PST 2025


https://github.com/skatrak updated https://github.com/llvm/llvm-project/pull/116050

>From 219d430cf7697eb2c4bcb5832695571ff472e0e2 Mon Sep 17 00:00:00 2001
From: Sergio Afonso <safonsof at amd.com>
Date: Fri, 8 Nov 2024 15:46:48 +0000
Subject: [PATCH] [OMPIRBuilder] Introduce struct to hold default kernel
 teams/threads

This patch introduces the `OpenMPIRBuilder::TargetKernelDefaultAttrs` structure
used to simplify passing default and constant values for number of teams and
threads, and possibly other target kernel-related information in the future.

This is used to forward values passed to `createTarget` to `createTargetInit`,
which previously used an unrelated set of default values.
---
 clang/lib/CodeGen/CGOpenMPRuntime.cpp         | 13 ++--
 clang/lib/CodeGen/CGOpenMPRuntime.h           |  9 +--
 clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp      | 10 +--
 .../llvm/Frontend/OpenMP/OMPIRBuilder.h       | 41 ++++++----
 llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp     | 75 +++++++++++--------
 .../Frontend/OpenMPIRBuilderTest.cpp          | 22 ++++--
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 12 +--
 .../LLVMIR/omptarget-region-device-llvm.mlir  |  2 +-
 8 files changed, 106 insertions(+), 78 deletions(-)

diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 1868b57cea9039..244e3066f8fe41 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -5880,10 +5880,13 @@ void CGOpenMPRuntime::emitUsesAllocatorsFini(CodeGenFunction &CGF,
 
 void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
     const OMPExecutableDirective &D, CodeGenFunction &CGF,
-    int32_t &MinThreadsVal, int32_t &MaxThreadsVal, int32_t &MinTeamsVal,
-    int32_t &MaxTeamsVal) {
+    llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(Attrs.MaxTeams.size() == 1 && Attrs.MaxThreads.size() == 1 &&
+         "invalid default attrs structure");
+  int32_t &MaxTeamsVal = Attrs.MaxTeams.front();
+  int32_t &MaxThreadsVal = Attrs.MaxThreads.front();
 
-  getNumTeamsExprForTargetDirective(CGF, D, MinTeamsVal, MaxTeamsVal);
+  getNumTeamsExprForTargetDirective(CGF, D, Attrs.MinTeams, MaxTeamsVal);
   getNumThreadsExprForTargetDirective(CGF, D, MaxThreadsVal,
                                       /*UpperBoundOnly=*/true);
 
@@ -5901,12 +5904,12 @@ void CGOpenMPRuntime::computeMinAndMaxThreadsAndTeams(
       else
         continue;
 
-      MinThreadsVal = std::max(MinThreadsVal, AttrMinThreadsVal);
+      Attrs.MinThreads = std::max(Attrs.MinThreads, AttrMinThreadsVal);
       if (AttrMaxThreadsVal > 0)
         MaxThreadsVal = MaxThreadsVal > 0
                             ? std::min(MaxThreadsVal, AttrMaxThreadsVal)
                             : AttrMaxThreadsVal;
-      MinTeamsVal = std::max(MinTeamsVal, AttrMinBlocksVal);
+      Attrs.MinTeams = std::max(Attrs.MinTeams, AttrMinBlocksVal);
       if (AttrMaxBlocksVal > 0)
         MaxTeamsVal = MaxTeamsVal > 0 ? std::min(MaxTeamsVal, AttrMaxBlocksVal)
                                       : AttrMaxBlocksVal;
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.h b/clang/lib/CodeGen/CGOpenMPRuntime.h
index 8ab5ee70a19fa2..3791bb71592350 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.h
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.h
@@ -313,12 +313,9 @@ class CGOpenMPRuntime {
   llvm::OpenMPIRBuilder OMPBuilder;
 
   /// Helper to determine the min/max number of threads/teams for \p D.
-  void computeMinAndMaxThreadsAndTeams(const OMPExecutableDirective &D,
-                                       CodeGenFunction &CGF,
-                                       int32_t &MinThreadsVal,
-                                       int32_t &MaxThreadsVal,
-                                       int32_t &MinTeamsVal,
-                                       int32_t &MaxTeamsVal);
+  void computeMinAndMaxThreadsAndTeams(
+      const OMPExecutableDirective &D, CodeGenFunction &CGF,
+      llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Helper to emit outlined function for 'target' directive.
   /// \param D Directive to emit.
diff --git a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
index dbb19f2a8d825a..81993dafae2b03 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntimeGPU.cpp
@@ -744,14 +744,12 @@ void CGOpenMPRuntimeGPU::emitNonSPMDKernel(const OMPExecutableDirective &D,
 void CGOpenMPRuntimeGPU::emitKernelInit(const OMPExecutableDirective &D,
                                         CodeGenFunction &CGF,
                                         EntryFunctionState &EST, bool IsSPMD) {
-  int32_t MinThreadsVal = 1, MaxThreadsVal = -1, MinTeamsVal = 1,
-          MaxTeamsVal = -1;
-  computeMinAndMaxThreadsAndTeams(D, CGF, MinThreadsVal, MaxThreadsVal,
-                                  MinTeamsVal, MaxTeamsVal);
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs Attrs;
+  Attrs.IsSPMD = IsSPMD;
+  computeMinAndMaxThreadsAndTeams(D, CGF, Attrs);
 
   CGBuilderTy &Bld = CGF.Builder;
-  Bld.restoreIP(OMPBuilder.createTargetInit(
-      Bld, IsSPMD, MinThreadsVal, MaxThreadsVal, MinTeamsVal, MaxTeamsVal));
+  Bld.restoreIP(OMPBuilder.createTargetInit(Bld, Attrs));
   if (!IsSPMD)
     emitGenericVarsProlog(CGF, EST.Loc);
 }
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 4ce47b1c05d9b0..8ca3bc08b5ad49 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2225,6 +2225,21 @@ class OpenMPIRBuilder {
           MapNamesArray(MapNamesArray) {}
   };
 
+  /// Container to pass the default attributes with which a kernel must be
+  /// launched, used to set kernel attributes and populate associated static
+  /// structures.
+  ///
+  /// For max values, < 0 means unset, == 0 means set but unknown at compile
+  /// time. The number of max values will be 1 except for the case where
+  /// ompx_bare is set.
+  struct TargetKernelDefaultAttrs {
+    bool IsSPMD = false;
+    SmallVector<int32_t, 3> MaxTeams = {-1};
+    int32_t MinTeams = 1;
+    SmallVector<int32_t, 3> MaxThreads = {-1};
+    int32_t MinThreads = 1;
+  };
+
   /// Data structure that contains the needed information to construct the
   /// kernel args vector.
   struct TargetKernelArgs {
@@ -2727,16 +2742,11 @@ class OpenMPIRBuilder {
   /// Create a runtime call for kmpc_target_init
   ///
   /// \param Loc The insert and source location description.
-  /// \param IsSPMD Flag to indicate if the kernel is an SPMD kernel or not.
-  /// \param MinThreads Minimal number of threads, or 0.
-  /// \param MaxThreads Maximal number of threads, or 0.
-  /// \param MinTeams Minimal number of teams, or 0.
-  /// \param MaxTeams Maximal number of teams, or 0.
-  InsertPointTy createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                 int32_t MinThreadsVal = 0,
-                                 int32_t MaxThreadsVal = 0,
-                                 int32_t MinTeamsVal = 0,
-                                 int32_t MaxTeamsVal = 0);
+  /// \param Attrs Structure containing the default attributes, including
+  ///        numbers of threads and teams to launch the kernel with.
+  InsertPointTy createTargetInit(
+      const LocationDescription &Loc,
+      const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs);
 
   /// Create a runtime call for kmpc_target_deinit
   ///
@@ -2961,8 +2971,8 @@ class OpenMPIRBuilder {
   /// \param CodeGenIP The insertion point where the call to the outlined
   /// function should be emitted.
   /// \param EntryInfo The entry information about the function.
-  /// \param NumTeams Number of teams specified in the num_teams clause.
-  /// \param NumThreads Number of teams specified in the thread_limit clause.
+  /// \param DefaultAttrs Structure containing the default numbers of threads
+  ///        and teams to launch the kernel with.
   /// \param Inputs The input values to the region that will be passed.
   /// as arguments to the outlined function.
   /// \param BodyGenCB Callback that will generate the region code.
@@ -2975,9 +2985,10 @@ class OpenMPIRBuilder {
       const LocationDescription &Loc, bool IsOffloadEntry,
       OpenMPIRBuilder::InsertPointTy AllocaIP,
       OpenMPIRBuilder::InsertPointTy CodeGenIP,
-      TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
-      ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
-      GenMapInfoCallbackTy GenMapInfoCB, TargetBodyGenCallbackTy BodyGenCB,
+      TargetRegionEntryInfo &EntryInfo,
+      const TargetKernelDefaultAttrs &DefaultAttrs,
+      SmallVectorImpl<Value *> &Inputs, GenMapInfoCallbackTy GenMapInfoCB,
+      TargetBodyGenCallbackTy BodyGenCB,
       TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
       SmallVector<DependData> Dependencies = {}, bool HasNowait = false);
 
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index 2b57a8dce3de5a..df9b35ddd80ca4 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -6119,10 +6119,12 @@ CallInst *OpenMPIRBuilder::createCachedThreadPrivate(
   return Builder.CreateCall(Fn, Args);
 }
 
-OpenMPIRBuilder::InsertPointTy
-OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
-                                  int32_t MinThreadsVal, int32_t MaxThreadsVal,
-                                  int32_t MinTeamsVal, int32_t MaxTeamsVal) {
+OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTargetInit(
+    const LocationDescription &Loc,
+    const llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs &Attrs) {
+  assert(!Attrs.MaxThreads.empty() && !Attrs.MaxTeams.empty() &&
+         "expected num_threads and num_teams to be specified");
+
   if (!updateToLocation(Loc))
     return Loc.IP;
 
@@ -6130,8 +6132,9 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
   Constant *SrcLocStr = getOrCreateSrcLocStr(Loc, SrcLocStrSize);
   Constant *Ident = getOrCreateIdent(SrcLocStr, SrcLocStrSize);
   Constant *IsSPMDVal = ConstantInt::getSigned(
-      Int8, IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
-  Constant *UseGenericStateMachineVal = ConstantInt::getSigned(Int8, !IsSPMD);
+      Int8, Attrs.IsSPMD ? OMP_TGT_EXEC_MODE_SPMD : OMP_TGT_EXEC_MODE_GENERIC);
+  Constant *UseGenericStateMachineVal =
+      ConstantInt::getSigned(Int8, !Attrs.IsSPMD);
   Constant *MayUseNestedParallelismVal = ConstantInt::getSigned(Int8, true);
   Constant *DebugIndentionLevelVal = ConstantInt::getSigned(Int16, 0);
 
@@ -6149,21 +6152,23 @@ OpenMPIRBuilder::createTargetInit(const LocationDescription &Loc, bool IsSPMD,
 
   // Manifest the launch configuration in the metadata matching the kernel
   // environment.
-  if (MinTeamsVal > 1 || MaxTeamsVal > 0)
-    writeTeamsForKernel(T, *Kernel, MinTeamsVal, MaxTeamsVal);
+  if (Attrs.MinTeams > 1 || Attrs.MaxTeams.front() > 0)
+    writeTeamsForKernel(T, *Kernel, Attrs.MinTeams, Attrs.MaxTeams.front());
 
-  // For max values, < 0 means unset, == 0 means set but unknown.
+  // If MaxThreads is not set, select the maximum between the default workgroup
+  // size and the MinThreads value.
+  int32_t MaxThreadsVal = Attrs.MaxThreads.front();
   if (MaxThreadsVal < 0)
     MaxThreadsVal = std::max(
-        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), MinThreadsVal);
+        int32_t(getGridValue(T, Kernel).GV_Default_WG_Size), Attrs.MinThreads);
 
   if (MaxThreadsVal > 0)
-    writeThreadBoundsForKernel(T, *Kernel, MinThreadsVal, MaxThreadsVal);
+    writeThreadBoundsForKernel(T, *Kernel, Attrs.MinThreads, MaxThreadsVal);
 
-  Constant *MinThreads = ConstantInt::getSigned(Int32, MinThreadsVal);
+  Constant *MinThreads = ConstantInt::getSigned(Int32, Attrs.MinThreads);
   Constant *MaxThreads = ConstantInt::getSigned(Int32, MaxThreadsVal);
-  Constant *MinTeams = ConstantInt::getSigned(Int32, MinTeamsVal);
-  Constant *MaxTeams = ConstantInt::getSigned(Int32, MaxTeamsVal);
+  Constant *MinTeams = ConstantInt::getSigned(Int32, Attrs.MinTeams);
+  Constant *MaxTeams = ConstantInt::getSigned(Int32, Attrs.MaxTeams.front());
   Constant *ReductionDataSize = ConstantInt::getSigned(Int32, 0);
   Constant *ReductionBufferLength = ConstantInt::getSigned(Int32, 0);
 
@@ -6730,8 +6735,9 @@ FunctionCallee OpenMPIRBuilder::createDispatchDeinitFunction() {
 }
 
 static Expected<Function *> createOutlinedFunction(
-    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, StringRef FuncName,
-    SmallVectorImpl<Value *> &Inputs,
+    OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    StringRef FuncName, SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
   SmallVector<Type *> ParameterTypes;
@@ -6798,7 +6804,7 @@ static Expected<Function *> createOutlinedFunction(
 
   // Insert target init call in the device compilation pass.
   if (OMPBuilder.Config.isTargetDevice())
-    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, /*IsSPMD*/ false));
+    Builder.restoreIP(OMPBuilder.createTargetInit(Builder, DefaultAttrs));
 
   BasicBlock *UserCodeEntryBB = Builder.GetInsertBlock();
 
@@ -6997,16 +7003,18 @@ static Function *emitTargetTaskProxyFunction(OpenMPIRBuilder &OMPBuilder,
 
 static Error emitTargetOutlinedFunction(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder, bool IsOffloadEntry,
-    TargetRegionEntryInfo &EntryInfo, Function *&OutlinedFn,
-    Constant *&OutlinedFnID, SmallVectorImpl<Value *> &Inputs,
+    TargetRegionEntryInfo &EntryInfo,
+    const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+    Function *&OutlinedFn, Constant *&OutlinedFnID,
+    SmallVectorImpl<Value *> &Inputs,
     OpenMPIRBuilder::TargetBodyGenCallbackTy &CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy &ArgAccessorFuncCB) {
 
   OpenMPIRBuilder::FunctionGenCallback &&GenerateOutlinedFunction =
-      [&OMPBuilder, &Builder, &Inputs, &CBFunc,
-       &ArgAccessorFuncCB](StringRef EntryFnName) {
-        return createOutlinedFunction(OMPBuilder, Builder, EntryFnName, Inputs,
-                                      CBFunc, ArgAccessorFuncCB);
+      [&](StringRef EntryFnName) {
+        return createOutlinedFunction(OMPBuilder, Builder, DefaultAttrs,
+                                      EntryFnName, Inputs, CBFunc,
+                                      ArgAccessorFuncCB);
       };
 
   return OMPBuilder.emitTargetRegionFunction(
@@ -7302,9 +7310,10 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 
 static void
 emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
-               OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
-               Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
-               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               const OpenMPIRBuilder::TargetKernelDefaultAttrs &DefaultAttrs,
+               Function *OutlinedFn, Constant *OutlinedFnID,
+               SmallVectorImpl<Value *> &Args,
                OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
                SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {},
                bool HasNoWait = false) {
@@ -7385,9 +7394,9 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 
   SmallVector<Value *, 3> NumTeamsC;
   SmallVector<Value *, 3> NumThreadsC;
-  for (auto V : NumTeams)
+  for (auto V : DefaultAttrs.MaxTeams)
     NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
-  for (auto V : NumThreads)
+  for (auto V : DefaultAttrs.MaxThreads)
     NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
@@ -7428,7 +7437,7 @@ emitTargetCall(OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
 OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
+    const TargetKernelDefaultAttrs &DefaultAttrs,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
@@ -7445,16 +7454,16 @@ OpenMPIRBuilder::InsertPointOrErrorTy OpenMPIRBuilder::createTarget(
   // the target region itself is generated using the callbacks CBFunc
   // and ArgAccessorFuncCB
   if (Error Err = emitTargetOutlinedFunction(
-          *this, Builder, IsOffloadEntry, EntryInfo, OutlinedFn, OutlinedFnID,
-          Args, CBFunc, ArgAccessorFuncCB))
+          *this, Builder, IsOffloadEntry, EntryInfo, DefaultAttrs, OutlinedFn,
+          OutlinedFnID, Args, CBFunc, ArgAccessorFuncCB))
     return Err;
 
   // If we are not on the target device, then we need to generate code
   // to make a remote call (offload) to the previously outlined function
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
-    emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependencies, HasNowait);
+    emitTargetCall(*this, Builder, AllocaIP, DefaultAttrs, OutlinedFn,
+                   OutlinedFnID, Args, GenMapInfoCB, Dependencies, HasNowait);
   return Builder.saveIP();
 }
 
diff --git a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
index cdca725b147436..04ecd7ef327d5a 100644
--- a/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
+++ b/llvm/unittests/Frontend/OpenMPIRBuilderTest.cpp
@@ -6229,10 +6229,14 @@ TEST_F(OpenMPIRBuilderTest, TargetRegion) {
 
   TargetRegionEntryInfo EntryInfo("func", 42, 4711, 17);
   OpenMPIRBuilder::LocationDescription OmpLoc({Builder.saveIP(), DL});
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
+      /*MinThreads=*/0};
+
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(OmpLoc, /*IsOffloadEntry=*/true, Builder.saveIP(),
-                              Builder.saveIP(), EntryInfo, -1, 0, Inputs,
+                              Builder.saveIP(), EntryInfo, DefaultAttrs, Inputs,
                               GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
   OMPBuilder.finalize();
@@ -6339,13 +6343,15 @@ TEST_F(OpenMPIRBuilderTest, TargetRegionDevice) {
                                    F->getEntryBlock().getFirstInsertionPt());
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
+      /*MinThreads=*/0};
 
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB));
+                              EntryInfo, DefaultAttrs, CapturedArgs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
@@ -6496,13 +6502,15 @@ TEST_F(OpenMPIRBuilderTest, ConstantAllocaRaise) {
                                    F->getEntryBlock().getFirstInsertionPt());
   TargetRegionEntryInfo EntryInfo("parent", /*DeviceID=*/1, /*FileID=*/2,
                                   /*Line=*/3, /*Count=*/0);
+  OpenMPIRBuilder::TargetKernelDefaultAttrs DefaultAttrs = {
+      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
+      /*MinThreads=*/0};
 
   ASSERT_EXPECTED_INIT(
       OpenMPIRBuilder::InsertPointTy, AfterIP,
       OMPBuilder.createTarget(Loc, /*IsOffloadEntry=*/true, EntryIP, EntryIP,
-                              EntryInfo, /*NumTeams=*/-1,
-                              /*NumThreads=*/0, CapturedArgs, GenMapInfoCB,
-                              BodyGenCB, SimpleArgAccessorCB));
+                              EntryInfo, DefaultAttrs, CapturedArgs,
+                              GenMapInfoCB, BodyGenCB, SimpleArgAccessorCB));
   Builder.restoreIP(AfterIP);
 
   Builder.CreateRetVoid();
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 35a3750e02a665..25b0ffe2ced6de 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -4084,9 +4084,6 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   if (!getTargetEntryUniqueInfo(entryInfo, targetOp, parentName))
     return failure();
 
-  int32_t defaultValTeams = -1;
-  int32_t defaultValThreads = 0;
-
   MapInfoData mapData;
   collectMapDataFromMapOperands(mapData, mapVars, moduleTranslation, dl,
                                 builder);
@@ -4118,6 +4115,11 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
                                         allocaIP, codeGenIP);
   };
 
+  // TODO: Populate default attributes based on the construct and clauses.
+  llvm::OpenMPIRBuilder::TargetKernelDefaultAttrs defaultAttrs = {
+      /*IsSPMD=*/false, /*MaxTeams=*/{-1}, /*MinTeams=*/0, /*MaxThreads=*/{0},
+      /*MinThreads=*/0};
+
   llvm::SmallVector<llvm::Value *, 4> kernelInput;
   for (size_t i = 0; i < mapVars.size(); ++i) {
     // declare target arguments are not passed to kernels as arguments
@@ -4141,8 +4143,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
   llvm::OpenMPIRBuilder::InsertPointOrErrorTy afterIP =
       moduleTranslation.getOpenMPBuilder()->createTarget(
           ompLoc, isOffloadEntry, allocaIP, builder.saveIP(), entryInfo,
-          defaultValTeams, defaultValThreads, kernelInput, genMapInfoCB, bodyCB,
-          argAccessorCB, dds, targetOp.getNowait());
+          defaultAttrs, kernelInput, genMapInfoCB, bodyCB, argAccessorCB, dds,
+          targetOp.getNowait());
 
   if (failed(handleError(afterIP, opInst)))
     return failure();
diff --git a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
index 8993c0e85c5dea..fa32a3030108d8 100644
--- a/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
+++ b/mlir/test/Target/LLVMIR/omptarget-region-device-llvm.mlir
@@ -29,7 +29,7 @@ module attributes {omp.is_target_device = true} {
 // CHECK:      @[[SRC_LOC:.*]] = private unnamed_addr constant [23 x i8] c"{{[^"]*}}", align 1
 // CHECK:      @[[IDENT:.*]] = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @[[SRC_LOC]] }, align 8
 // CHECK:      @[[DYNA_ENV:.*]] = weak_odr protected global %struct.DynamicEnvironmentTy zeroinitializer
-// CHECK:      @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 0, i32 0, i32 0, i32 0, i32 0, i32 0 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] }
+// CHECK:      @[[KERNEL_ENV:.*]] = weak_odr protected constant %struct.KernelEnvironmentTy { %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 1, i32 0, i32 0, i32 0, i32 -1, i32 0, i32 0 }, ptr @[[IDENT]], ptr @[[DYNA_ENV]] }
 // CHECK:      define weak_odr protected void @__omp_offloading_{{[^_]+}}_{{[^_]+}}_omp_target_region__l{{[0-9]+}}(ptr %[[DYN_PTR:.*]], ptr %[[ADDR_A:.*]], ptr %[[ADDR_B:.*]], ptr %[[ADDR_C:.*]])
 // CHECK:        %[[TMP_A:.*]] = alloca ptr, align 8
 // CHECK:        store ptr %[[ADDR_A]], ptr %[[TMP_A]], align 8



More information about the llvm-commits mailing list