[llvm-branch-commits] [clang] [llvm] [Clang][OMPX] Add the code generation for multi-dim `thread_limit` clause (PR #102717)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Aug 9 20:35:35 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Shilei Tian (shiltian)

<details>
<summary>Changes</summary>



---
Full diff: https://github.com/llvm/llvm-project/pull/102717.diff


4 Files Affected:

- (modified) clang/lib/CodeGen/CGOpenMPRuntime.cpp (+17-12) 
- (modified) clang/test/OpenMP/target_teams_codegen.cpp (+6-6) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+13-13) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+18-13) 


``````````diff
diff --git a/clang/lib/CodeGen/CGOpenMPRuntime.cpp b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
index 8c5e4aa9c037e2..6c0c8646898cc6 100644
--- a/clang/lib/CodeGen/CGOpenMPRuntime.cpp
+++ b/clang/lib/CodeGen/CGOpenMPRuntime.cpp
@@ -9588,15 +9588,17 @@ static void genMapInfo(const OMPExecutableDirective &D, CodeGenFunction &CGF,
   genMapInfo(MEHandler, CGF, CombinedInfo, OMPBuilder, MappedVarSet);
 }
 
-static void emitNumTeamsForBareTargetDirective(
+template <typename ClauseTy>
+static void emitClauseForBareTargetDirective(
     CodeGenFunction &CGF, const OMPExecutableDirective &D,
-    llvm::SmallVectorImpl<llvm::Value *> &NumTeams) {
-  const auto *C = D.getSingleClause<OMPNumTeamsClause>();
-  assert(!C->varlist_empty() && "ompx_bare requires explicit num_teams");
-  CodeGenFunction::RunCleanupsScope NumTeamsScope(CGF);
-  for (auto *E : C->getNumTeams()) {
+    llvm::SmallVectorImpl<llvm::Value *> &Values) {
+  const auto *C = D.getSingleClause<ClauseTy>();
+  assert(!C->varlist_empty() &&
+         "ompx_bare requires explicit num_teams and thread_limit");
+  CodeGenFunction::RunCleanupsScope Scope(CGF);
+  for (auto *E : C->varlist()) {
     llvm::Value *V = CGF.EmitScalarExpr(E);
-    NumTeams.push_back(
+    Values.push_back(
         CGF.Builder.CreateIntCast(V, CGF.Int32Ty, /*isSigned=*/true));
   }
 }
@@ -9672,14 +9674,17 @@ static void emitTargetCallKernelLaunch(
 
     bool IsBare = D.hasClausesOfKind<OMPXBareClause>();
     SmallVector<llvm::Value *, 3> NumTeams;
-    if (IsBare)
-      emitNumTeamsForBareTargetDirective(CGF, D, NumTeams);
-    else
+    SmallVector<llvm::Value *, 3> NumThreads;
+    if (IsBare) {
+      emitClauseForBareTargetDirective<OMPNumTeamsClause>(CGF, D, NumTeams);
+      emitClauseForBareTargetDirective<OMPThreadLimitClause>(CGF, D,
+                                                             NumThreads);
+    } else {
       NumTeams.push_back(OMPRuntime->emitNumTeamsForTargetDirective(CGF, D));
+      NumThreads.push_back(OMPRuntime->emitNumThreadsForTargetDirective(CGF, D));
+    }
 
     llvm::Value *DeviceID = emitDeviceID(Device, CGF);
-    llvm::Value *NumThreads =
-        OMPRuntime->emitNumThreadsForTargetDirective(CGF, D);
     llvm::Value *RTLoc = OMPRuntime->emitUpdateLocation(CGF, D.getBeginLoc());
     llvm::Value *NumIterations =
         OMPRuntime->emitTargetNumIterationsCall(CGF, D, SizeEmitter);
diff --git a/clang/test/OpenMP/target_teams_codegen.cpp b/clang/test/OpenMP/target_teams_codegen.cpp
index 9cab8eef148833..13d44e127201bd 100644
--- a/clang/test/OpenMP/target_teams_codegen.cpp
+++ b/clang/test/OpenMP/target_teams_codegen.cpp
@@ -127,13 +127,13 @@ int foo(int n) {
     aa += 1;
   }
 
-  #pragma omp target teams ompx_bare num_teams(1, 2) thread_limit(1)
+  #pragma omp target teams ompx_bare num_teams(1, 2) thread_limit(1, 2)
   {
     a += 1;
     aa += 1;
   }
 
-  #pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(1)
+  #pragma omp target teams ompx_bare num_teams(1, 2, 3) thread_limit(1, 2, 3)
   {
     a += 1;
     aa += 1;
@@ -667,7 +667,7 @@ int bar(int n){
 // CHECK1-NEXT:    [[TMP144:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 10
 // CHECK1-NEXT:    store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP144]], align 4
 // CHECK1-NEXT:    [[TMP145:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 11
-// CHECK1-NEXT:    store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP145]], align 4
+// CHECK1-NEXT:    store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP145]], align 4
 // CHECK1-NEXT:    [[TMP146:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 12
 // CHECK1-NEXT:    store i32 0, ptr [[TMP146]], align 4
 // CHECK1-NEXT:    [[TMP147:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l130.region_id, ptr [[KERNEL_ARGS29]])
@@ -720,7 +720,7 @@ int bar(int n){
 // CHECK1-NEXT:    [[TMP171:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 10
 // CHECK1-NEXT:    store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP171]], align 4
 // CHECK1-NEXT:    [[TMP172:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 11
-// CHECK1-NEXT:    store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP172]], align 4
+// CHECK1-NEXT:    store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP172]], align 4
 // CHECK1-NEXT:    [[TMP173:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 12
 // CHECK1-NEXT:    store i32 0, ptr [[TMP173]], align 4
 // CHECK1-NEXT:    [[TMP174:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l136.region_id, ptr [[KERNEL_ARGS37]])
@@ -2458,7 +2458,7 @@ int bar(int n){
 // CHECK3-NEXT:    [[TMP142:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 10
 // CHECK3-NEXT:    store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP142]], align 4
 // CHECK3-NEXT:    [[TMP143:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 11
-// CHECK3-NEXT:    store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP143]], align 4
+// CHECK3-NEXT:    store [3 x i32] [i32 1, i32 2, i32 0], ptr [[TMP143]], align 4
 // CHECK3-NEXT:    [[TMP144:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS29]], i32 0, i32 12
 // CHECK3-NEXT:    store i32 0, ptr [[TMP144]], align 4
 // CHECK3-NEXT:    [[TMP145:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l130.region_id, ptr [[KERNEL_ARGS29]])
@@ -2511,7 +2511,7 @@ int bar(int n){
 // CHECK3-NEXT:    [[TMP169:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 10
 // CHECK3-NEXT:    store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP169]], align 4
 // CHECK3-NEXT:    [[TMP170:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 11
-// CHECK3-NEXT:    store [3 x i32] [i32 1, i32 0, i32 0], ptr [[TMP170]], align 4
+// CHECK3-NEXT:    store [3 x i32] [i32 1, i32 2, i32 3], ptr [[TMP170]], align 4
 // CHECK3-NEXT:    [[TMP171:%.*]] = getelementptr inbounds nuw [[STRUCT___TGT_KERNEL_ARGUMENTS]], ptr [[KERNEL_ARGS37]], i32 0, i32 12
 // CHECK3-NEXT:    store i32 0, ptr [[TMP171]], align 4
 // CHECK3-NEXT:    [[TMP172:%.*]] = call i32 @__tgt_target_kernel(ptr @[[GLOB1]], i64 -1, i32 1, i32 1, ptr @.{{__omp_offloading_[0-9a-z]+_[0-9a-z]+}}__Z3fooi_l136.region_id, ptr [[KERNEL_ARGS37]])
diff --git a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
index 9e4e7ebf2a5703..4be0159fb1dd9f 100644
--- a/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
+++ b/llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h
@@ -2195,7 +2195,7 @@ class OpenMPIRBuilder {
     /// The number of teams.
     ArrayRef<Value *> NumTeams;
     /// The number of threads.
-    Value *NumThreads = nullptr;
+    ArrayRef<Value *> NumThreads;
     /// The size of the dynamic shared memory.
     Value *DynCGGroupMem = nullptr;
     /// True if the kernel has 'no wait' clause.
@@ -2205,7 +2205,8 @@ class OpenMPIRBuilder {
     TargetKernelArgs() {}
     TargetKernelArgs(unsigned NumTargetItems, TargetDataRTArgs RTArgs,
                      Value *NumIterations, ArrayRef<Value *> NumTeams,
-                     Value *NumThreads, Value *DynCGGroupMem, bool HasNoWait)
+                     ArrayRef<Value *> NumThreads, Value *DynCGGroupMem,
+                     bool HasNoWait)
         : NumTargetItems(NumTargetItems), RTArgs(RTArgs),
           NumIterations(NumIterations), NumTeams(NumTeams),
           NumThreads(NumThreads), DynCGGroupMem(DynCGGroupMem),
@@ -2852,17 +2853,16 @@ class OpenMPIRBuilder {
   /// instructions for passed in target arguments where neccessary
   /// \param Dependencies A vector of DependData objects that carry
   // dependency information as passed in the depend clause
-  InsertPointTy createTarget(const LocationDescription &Loc,
-                             bool IsOffloadEntry,
-                             OpenMPIRBuilder::InsertPointTy AllocaIP,
-                             OpenMPIRBuilder::InsertPointTy CodeGenIP,
-                             TargetRegionEntryInfo &EntryInfo,
-                             ArrayRef<int32_t> NumTeams, int32_t NumThreads,
-                             SmallVectorImpl<Value *> &Inputs,
-                             GenMapInfoCallbackTy GenMapInfoCB,
-                             TargetBodyGenCallbackTy BodyGenCB,
-                             TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
-                             SmallVector<DependData> Dependencies = {});
+  InsertPointTy
+  createTarget(const LocationDescription &Loc, bool IsOffloadEntry,
+               OpenMPIRBuilder::InsertPointTy AllocaIP,
+               OpenMPIRBuilder::InsertPointTy CodeGenIP,
+               TargetRegionEntryInfo &EntryInfo, ArrayRef<int32_t> NumTeams,
+               ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Inputs,
+               GenMapInfoCallbackTy GenMapInfoCB,
+               TargetBodyGenCallbackTy BodyGenCB,
+               TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
+               SmallVector<DependData> Dependencies = {});
 
   /// Returns __kmpc_for_static_init_* runtime function for the specified
   /// size \a IVSize and sign \a IVSigned. Will create a distribute call
diff --git a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
index b481520fa6c6f9..f46531cb3bad40 100644
--- a/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
+++ b/llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp
@@ -505,11 +505,14 @@ void OpenMPIRBuilder::getKernelArgsVector(TargetKernelArgs &KernelArgs,
 
   Value *NumTeams3D =
       Builder.CreateInsertValue(ZeroArray, KernelArgs.NumTeams[0], {0});
+  Value *NumThreads3D =
+      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads[0], {0});
   for (unsigned I = 1; I < std::min(KernelArgs.NumTeams.size(), MaxDim); ++I)
     NumTeams3D =
         Builder.CreateInsertValue(NumTeams3D, KernelArgs.NumTeams[I], {I});
-  Value *NumThreads3D =
-      Builder.CreateInsertValue(ZeroArray, KernelArgs.NumThreads, {0});
+  for (unsigned I = 1; I < std::min(KernelArgs.NumThreads.size(), MaxDim); ++I)
+    NumThreads3D =
+        Builder.CreateInsertValue(NumThreads3D, KernelArgs.NumThreads[I], {I});
 
   ArgsVector = {Version,
                 PointerNum,
@@ -1114,9 +1117,9 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::emitKernelLaunch(
   // __tgt_target_teams() launches a GPU kernel with the requested number
   // of teams and threads so no additional calls to the runtime are required.
   // Check the error code and execute the host version if required.
-  Builder.restoreIP(emitTargetKernel(Builder, AllocaIP, Return, RTLoc, DeviceID,
-                                     Args.NumTeams.front(), Args.NumThreads,
-                                     OutlinedFnID, ArgsVector));
+  Builder.restoreIP(emitTargetKernel(
+      Builder, AllocaIP, Return, RTLoc, DeviceID, Args.NumTeams.front(),
+      Args.NumThreads.front(), OutlinedFnID, ArgsVector));
 
   BasicBlock *OffloadFailedBlock =
       BasicBlock::Create(Builder.getContext(), "omp_offload.failed");
@@ -7075,8 +7078,8 @@ void OpenMPIRBuilder::emitOffloadingArraysAndArgs(
 static void emitTargetCall(
     OpenMPIRBuilder &OMPBuilder, IRBuilderBase &Builder,
     OpenMPIRBuilder::InsertPointTy AllocaIP, Function *OutlinedFn,
-    Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams, int32_t NumThreads,
-    SmallVectorImpl<Value *> &Args,
+    Constant *OutlinedFnID, ArrayRef<int32_t> NumTeams,
+    ArrayRef<int32_t> NumThreads, SmallVectorImpl<Value *> &Args,
     OpenMPIRBuilder::GenMapInfoCallbackTy GenMapInfoCB,
     SmallVector<llvm::OpenMPIRBuilder::DependData> Dependencies = {}) {
   // Generate a function call to the host fallback implementation of the target
@@ -7123,13 +7126,15 @@ static void emitTargetCall(
                                           /*ForEndCall=*/false);
 
   SmallVector<Value *, 3> NumTeamsC;
+  SmallVector<Value *, 3> NumThreadsC;
   for (auto V : NumTeams)
     NumTeamsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
+  for (auto V : NumThreads)
+    NumThreadsC.push_back(llvm::ConstantInt::get(Builder.getInt32Ty(), V));
 
   unsigned NumTargetItems = Info.NumberOfPtrs;
   // TODO: Use correct device ID
   Value *DeviceID = Builder.getInt64(OMP_DEVICEID_UNDEF);
-  Value *NumThreadsVal = Builder.getInt32(NumThreads);
   uint32_t SrcLocStrSize;
   Constant *SrcLocStr = OMPBuilder.getOrCreateDefaultSrcLocStr(SrcLocStrSize);
   Value *RTLoc = OMPBuilder.getOrCreateIdent(SrcLocStr, SrcLocStrSize,
@@ -7140,8 +7145,8 @@ static void emitTargetCall(
   Value *DynCGGroupMem = Builder.getInt32(0);
 
   OpenMPIRBuilder::TargetKernelArgs KArgs(NumTargetItems, RTArgs, NumIterations,
-                                          NumTeamsC, NumThreadsVal,
-                                          DynCGGroupMem, HasNoWait);
+                                          NumTeamsC, NumThreadsC, DynCGGroupMem,
+                                          HasNoWait);
 
   // The presence of certain clauses on the target directive require the
   // explicit generation of the target task.
@@ -7159,11 +7164,11 @@ static void emitTargetCall(
 OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
     const LocationDescription &Loc, bool IsOffloadEntry, InsertPointTy AllocaIP,
     InsertPointTy CodeGenIP, TargetRegionEntryInfo &EntryInfo,
-    ArrayRef<int32_t> NumTeams, int32_t NumThreads,
+    ArrayRef<int32_t> NumTeams, ArrayRef<int32_t> NumThreads,
     SmallVectorImpl<Value *> &Args, GenMapInfoCallbackTy GenMapInfoCB,
     OpenMPIRBuilder::TargetBodyGenCallbackTy CBFunc,
     OpenMPIRBuilder::TargetGenArgAccessorsCallbackTy ArgAccessorFuncCB,
-    SmallVector<DependData> Dependenciess) {
+    SmallVector<DependData> Dependencies) {
 
   if (!updateToLocation(Loc))
     return InsertPointTy();
@@ -7184,7 +7189,7 @@ OpenMPIRBuilder::InsertPointTy OpenMPIRBuilder::createTarget(
   // that represents the target region. Do that now.
   if (!Config.isTargetDevice())
     emitTargetCall(*this, Builder, AllocaIP, OutlinedFn, OutlinedFnID, NumTeams,
-                   NumThreads, Args, GenMapInfoCB, Dependenciess);
+                   NumThreads, Args, GenMapInfoCB, Dependencies);
   return Builder.saveIP();
 }
 

``````````

</details>


https://github.com/llvm/llvm-project/pull/102717


More information about the llvm-branch-commits mailing list