[Mlir-commits] [mlir] [OpenMP][MLIR] Add thread_limit mlir->llvm lowering (PR #179608)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 10 05:31:22 PST 2026


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-flang-openmp

Author: Chaitanya (skc7)

<details>
<summary>Changes</summary>

The `thread_limit` clause in the OpenMP MLIR dialect now supports multiple dimensions.
This PR updates `OpenMPToLLVMIRTranslation` with the following changes:

Updates `checkThreadLimit` to allow multi-dimensional `thread_limit` values with up to 3 dimensions.

Updates `initTargetDefaultAttrs` to extract all `thread_limit` dimensions from `teamsOp.getThreadLimitVars()`. Only the first dimension is used for compile-time constant extraction.

Updates `initTargetRuntimeAttrs` to extract all dimensions into the `threadLimitVars` vector.

---
Full diff: https://github.com/llvm/llvm-project/pull/179608.diff


4 Files Affected:

- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+47-20) 
- (modified) mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir (+3-3) 
- (modified) mlir/test/Target/LLVMIR/openmp-teams.mlir (+36) 
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+3-3) 


``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 022322502a755..f364e1e1818c3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -386,8 +386,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
   };
 
   auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
-    if (op.hasThreadLimitMultiDim())
-      result = todo("thread_limit with multi-dimensional values");
+    if (op.getThreadLimitDimsCount() > 3)
+      result = todo("thread_limit with more than 3 dimensions");
   };
 
   LogicalResult result = success();
@@ -6040,7 +6040,7 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 static void
 extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
                        Value &numTeamsLower, Value &numTeamsUpper,
-                       Value &threadLimit,
+                       llvm::SmallVectorImpl<Value> &threadLimitVars,
                        llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
                        llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
                        llvm::SmallVectorImpl<Value> *steps = nullptr) {
@@ -6057,10 +6057,18 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
             else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
                                         blockArg))
               numTeamsUpper = hostEvalVar;
-            else if (!teamsOp.getThreadLimitVars().empty() &&
-                     teamsOp.getThreadLimit(0) == blockArg)
-              threadLimit = hostEvalVar;
-            else
+            else if (llvm::is_contained(teamsOp.getThreadLimitVars(),
+                                        blockArg)) {
+              for (auto [i, limitVar] :
+                   llvm::enumerate(teamsOp.getThreadLimitVars())) {
+                if (limitVar == blockArg) {
+                  if (threadLimitVars.size() <= i)
+                    threadLimitVars.resize(i + 1);
+                  threadLimitVars[i] = hostEvalVar;
+                  break;
+                }
+              }
+            } else
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
@@ -6168,10 +6176,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                        bool isTargetDevice, bool isGPU) {
   // TODO: Handle constant 'if' clauses.
 
-  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+  Value numThreads, numTeamsLower, numTeamsUpper;
+  llvm::SmallVector<Value> threadLimitVars;
   if (!isTargetDevice) {
     extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                           threadLimit);
+                           threadLimitVars);
   } else {
     // In the target device, values for these clauses are not passed as
     // host_eval, but instead evaluated prior to entry to the region. This
@@ -6181,8 +6190,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       // Handle num_teams upper bounds (only first value for now)
       if (!teamsOp.getNumTeamsUpperVars().empty())
         numTeamsUpper = teamsOp.getNumTeams(0);
-      if (!teamsOp.getThreadLimitVars().empty())
-        threadLimit = teamsOp.getThreadLimit(0);
+      threadLimitVars.reserve(teamsOp.getThreadLimitVars().size());
+      for (auto limitVar : teamsOp.getThreadLimitVars())
+        threadLimitVars.push_back(limitVar);
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
@@ -6231,7 +6241,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
   int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
   if (!targetOp.getThreadLimitVars().empty())
     setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
-  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
+  if (!threadLimitVars.empty())
+    setMaxValueFromClause(threadLimitVars[0], teamsThreadLimitVal);
 
   // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
   int32_t maxThreadsVal = -1;
@@ -6278,8 +6289,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
 
   attrs.MinTeams = minTeamsVal;
+  // Always resize to 3 dimensions to match TargetKernelRuntimeAttrs
+  attrs.MaxTeams.resize(3, -1);
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
+  attrs.MaxThreads.resize(3, -1);
   attrs.MaxThreads.front() = combinedMaxThreadsVal;
   attrs.ReductionDataSize = reductionDataSize;
   // TODO: Allow modified buffer length similar to
@@ -6302,17 +6316,22 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
-  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  Value numThreads, numTeamsLower, numTeamsUpper;
+  llvm::SmallVector<Value> threadLimitVars;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);
   extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
+                         threadLimitVars, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
+  // Resize to 3 dimensions to match TargetKernelDefaultAttrs
+  attrs.TargetThreadLimit.resize(3);
   if (!targetOp.getThreadLimitVars().empty()) {
-    Value targetThreadLimit = targetOp.getThreadLimit(0);
-    attrs.TargetThreadLimit.front() =
-        moduleTranslation.lookupValue(targetThreadLimit);
+    for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
+      if (limitVar) {
+        attrs.TargetThreadLimit[i] = moduleTranslation.lookupValue(limitVar);
+      }
+    }
   }
 
   // The __kmpc_push_num_teams_51 function expects int32 as the arguments.  So,
@@ -6322,13 +6341,21 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.MinTeams = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
 
+  // Resize to 3 dimensions to match TargetKernelDefaultAttrs
+  attrs.MaxTeams.resize(3);
   if (numTeamsUpper)
     attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
 
-  if (teamsThreadLimit)
-    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
-        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
+  attrs.TeamsThreadLimit.resize(3);
+  if (!threadLimitVars.empty()) {
+    for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
+      if (limitVar) {
+        attrs.TeamsThreadLimit[i] = builder.CreateSExtOrTrunc(
+            moduleTranslation.lookupValue(limitVar), builder.getInt32Ty());
+      }
+    }
+  }
 
   if (numThreads)
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index abc67017b620d..9cc9a5025b613 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
 
 // CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
 // CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
 
 // CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
 
 // CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
 
 // CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
 // CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]])
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 4690b51122beb..adca15d1c5fcc 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -311,3 +311,39 @@ llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeam
     llvm.call @afterTeams() : () -> ()
     llvm.return
 }
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_2d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i32 [[LIMIT_Y:.+]])
+llvm.func @omp_teams_thread_limit_2d(%limitX: i32, %limitY: i32) {
+    // Multi-dimensional thread_limit: all dimensions are passed
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams thread_limit(%limitX, %limitY : i32, i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    llvm.return
+}
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_3d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i64 [[LIMIT_Y:.+]], i16 [[LIMIT_Z:.+]])
+llvm.func @omp_teams_thread_limit_3d(%limitX: i32, %limitY: i64, %limitZ: i16) {
+    // Multi-dimensional thread_limit with mixed types: all dimensions are passed
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams thread_limit(%limitX, %limitY, %limitZ : i32, i64, i16) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 36338e5cb1bed..ec1c134c2c07a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -454,10 +454,10 @@ llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
-llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
-  // expected-error at below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_thread_limit_too_many_dims(%lb : i32, %ub : i32) {
+  // expected-error at below {{not yet implemented: Unhandled clause thread_limit with more than 3 dimensions in omp.teams operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.teams}}
-  omp.teams thread_limit(%lb, %ub : i32, i32) {
+  omp.teams thread_limit(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
     omp.terminator
   }
   llvm.return

``````````

</details>


https://github.com/llvm/llvm-project/pull/179608


More information about the Mlir-commits mailing list