[llvm-branch-commits] [mlir] [OpenMP][MLIR] Add thread_limit mlir->llvm lowering (PR #179608)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Apr 16 03:33:54 PDT 2026


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/179608

>From 6adef736e0e5f304c0535012095cf8261631c237 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 4 Feb 2026 10:21:06 +0530
Subject: [PATCH 1/3] [OpenMP][MLIR] Add thread_limit mlir->llvm lowering

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 131 +++++++++++-------
 .../LLVMIR/openmp-target-launch-host.mlir     |   6 +-
 mlir/test/Target/LLVMIR/openmp-teams.mlir     |  36 +++++
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |   6 +-
 4 files changed, 124 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 64c7e5700c771..ad7705fc014a5 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -390,8 +390,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
   };
 
   auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
-    if (op.hasThreadLimitMultiDim())
-      result = todo("thread_limit with multi-dimensional values");
+    if (op.getThreadLimitDimsCount() > 3)
+      result = todo("thread_limit with more than 3 dimensions");
   };
 
   LogicalResult result = success();
@@ -6508,12 +6508,14 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void extractHostEvalClauses(
-    omp::TargetOp targetOp, llvm::SmallVectorImpl<Value> &numThreadsVars,
-    Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit,
-    llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
-    llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
-    llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void
+extractHostEvalClauses(omp::TargetOp targetOp,
+                       llvm::SmallVectorImpl<Value> &numThreadsVars,
+                       Value &numTeamsLower, Value &numTeamsUpper,
+                       llvm::SmallVectorImpl<Value> &threadLimitVars,
+                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -6527,10 +6529,18 @@ static void extractHostEvalClauses(
             else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
                                         blockArg))
               numTeamsUpper = hostEvalVar;
-            else if (!teamsOp.getThreadLimitVars().empty() &&
-                     teamsOp.getThreadLimit(0) == blockArg)
-              threadLimit = hostEvalVar;
-            else
+            else if (llvm::is_contained(teamsOp.getThreadLimitVars(),
+                                        blockArg)) {
+              for (auto [i, limitVar] :
+                   llvm::enumerate(teamsOp.getThreadLimitVars())) {
+                if (limitVar == blockArg) {
+                  if (threadLimitVars.size() <= i)
+                    threadLimitVars.resize(i + 1);
+                  threadLimitVars[i] = hostEvalVar;
+                  break;
+                }
+              }
+            } else
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
@@ -6653,11 +6663,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                        bool isTargetDevice, bool isGPU) {
   // TODO: Handle constant 'if' clauses.
 
-  Value numTeamsLower, numTeamsUpper, threadLimit;
-  llvm::SmallVector<Value> numThreadsVars;
+  Value numTeamsLower, numTeamsUpper;
+  llvm::SmallVector<Value> numThreadsVars, threadLimitVars;
   if (!isTargetDevice) {
     extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower,
-                           numTeamsUpper, threadLimit);
+                           numTeamsUpper, threadLimitVars);
   } else {
     // In the target device, values for these clauses are not passed as
     // host_eval, but instead evaluated prior to entry to the region. This
@@ -6667,8 +6677,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       // Handle num_teams upper bounds (only first value for now)
       if (!teamsOp.getNumTeamsUpperVars().empty())
         numTeamsUpper = teamsOp.getNumTeams(0);
-      if (!teamsOp.getThreadLimitVars().empty())
-        threadLimit = teamsOp.getThreadLimit(0);
+      threadLimitVars.reserve(teamsOp.getThreadLimitVars().size());
+      for (auto limitVar : teamsOp.getThreadLimitVars())
+        threadLimitVars.push_back(limitVar);
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
@@ -6715,33 +6726,45 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       result = 0;
   };
 
-  // Extract 'thread_limit' clause from 'target' and 'teams' directives.
-  int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
-  if (!targetOp.getThreadLimitVars().empty())
-    setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
-  setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
+  // Extract multi-dimensional 'thread_limit' clause from 'target' and 'teams'.
+  llvm::SmallVector<int32_t, 3> targetThreadLimitVals(3, -1);
+  llvm::SmallVector<int32_t, 3> teamsThreadLimitVals(3, -1);
+  for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
+    if (i < 3)
+      setMaxValueFromClause(limitVar, targetThreadLimitVals[i]);
+  }
+  for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
+    if (i < 3)
+      setMaxValueFromClause(limitVar, teamsThreadLimitVals[i]);
+  }
 
-  // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
-  int32_t maxThreadsVal = -1;
+  // Extract multi-dimensional 'num_threads' clause from 'parallel' or set to 1
+  // if it's SIMD.
+  llvm::SmallVector<int32_t, 3> maxThreadsVals(3, -1);
   if (castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
-    // For multi-dimensional num_threads, only use the first dimension for now
-    if (!numThreadsVars.empty())
-      setMaxValueFromClause(numThreadsVars[0], maxThreadsVal);
+    for (auto [i, threadsVar] : llvm::enumerate(numThreadsVars)) {
+      if (i < 3)
+        setMaxValueFromClause(threadsVar, maxThreadsVals[i]);
+    }
   } else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
-                                                /*immediateParent=*/true))
-    maxThreadsVal = 1;
+                                                /*immediateParent=*/true)) {
+    maxThreadsVals[0] = 1;
+  }
 
   // For max values, < 0 means unset, == 0 means set but unknown. Select the
-  // minimum value between 'max_threads' and 'thread_limit' clauses that were
-  // set.
-  int32_t combinedMaxThreadsVal = targetThreadLimitVal;
-  if (combinedMaxThreadsVal < 0 ||
-      (teamsThreadLimitVal >= 0 && teamsThreadLimitVal < combinedMaxThreadsVal))
-    combinedMaxThreadsVal = teamsThreadLimitVal;
-
-  if (combinedMaxThreadsVal < 0 ||
-      (maxThreadsVal >= 0 && maxThreadsVal < combinedMaxThreadsVal))
-    combinedMaxThreadsVal = maxThreadsVal;
+  // minimum value between 'num_threads' and 'thread_limit' clauses that were
+  // set, for each dimension.
+  llvm::SmallVector<int32_t, 3> combinedMaxThreadsVals(3, -1);
+  for (size_t i = 0; i < 3; ++i) {
+    int32_t combined = targetThreadLimitVals[i];
+    if (combined < 0 ||
+        (teamsThreadLimitVals[i] >= 0 && teamsThreadLimitVals[i] < combined))
+      combined = teamsThreadLimitVals[i];
+    if (combined < 0 ||
+        (maxThreadsVals[i] >= 0 && maxThreadsVals[i] < combined))
+      combined = maxThreadsVals[i];
+    combinedMaxThreadsVals[i] = combined;
+  }
 
   int32_t reductionDataSize = 0;
   if (isGPU && capturedOp) {
@@ -6770,7 +6793,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
   attrs.MinTeams = minTeamsVal;
   attrs.MaxTeams.front() = maxTeamsVal;
   attrs.MinThreads = 1;
-  attrs.MaxThreads.front() = combinedMaxThreadsVal;
+  attrs.MaxThreads = combinedMaxThreadsVals;
   attrs.ReductionDataSize = reductionDataSize;
   // TODO: Allow modified buffer length similar to
   // fopenmp-cuda-teams-reduction-recs-num flag in clang.
@@ -6792,18 +6815,22 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
-  Value numTeamsLower, numTeamsUpper, teamsThreadLimit;
-  llvm::SmallVector<Value> numThreadsVars;
+  Value numTeamsLower, numTeamsUpper;
+  llvm::SmallVector<Value> numThreadsVars, threadLimitVars;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);
   extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, numTeamsUpper,
-                         teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
+                         threadLimitVars, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
+  // Resize to 3 dimensions to match TargetKernelDefaultAttrs
+  attrs.TargetThreadLimit.resize(3);
   if (!targetOp.getThreadLimitVars().empty()) {
-    Value targetThreadLimit = targetOp.getThreadLimit(0);
-    attrs.TargetThreadLimit.front() =
-        moduleTranslation.lookupValue(targetThreadLimit);
+    for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
+      if (limitVar) {
+        attrs.TargetThreadLimit[i] = moduleTranslation.lookupValue(limitVar);
+      }
+    }
   }
 
   // The __kmpc_push_num_teams_51 function expects int32 as the arguments.  So,
@@ -6817,9 +6844,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
 
-  if (teamsThreadLimit)
-    attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
-        moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
+  attrs.TeamsThreadLimit.resize(3);
+  if (!threadLimitVars.empty()) {
+    for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
+      if (limitVar) {
+        attrs.TeamsThreadLimit[i] = builder.CreateSExtOrTrunc(
+            moduleTranslation.lookupValue(limitVar), builder.getInt32Ty());
+      }
+    }
+  }
 
   // Handle multi-dimensional num_threads (only first value for now)
   if (!numThreadsVars.empty())
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index deb1e6cef50bd..a8b8696190604 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
 
 // CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
 // CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
 
 // CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
 
 // CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
 
 // CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
 // CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]], ptr null)
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 4690b51122beb..adca15d1c5fcc 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -311,3 +311,39 @@ llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeam
     llvm.call @afterTeams() : () -> ()
     llvm.return
 }
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_2d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i32 [[LIMIT_Y:.+]])
+llvm.func @omp_teams_thread_limit_2d(%limitX: i32, %limitY: i32) {
+    // Multi-dimensional thread_limit: all dimensions are passed
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams thread_limit(%limitX, %limitY : i32, i32) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    llvm.return
+}
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_3d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i64 [[LIMIT_Y:.+]], i16 [[LIMIT_Z:.+]])
+llvm.func @omp_teams_thread_limit_3d(%limitX: i32, %limitY: i64, %limitZ: i16) {
+    // Multi-dimensional thread_limit with mixed types: all dimensions are passed
+    // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+    // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+    // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+    omp.teams thread_limit(%limitX, %limitY, %limitZ : i32, i64, i16) {
+        llvm.call @duringTeams() : () -> ()
+        omp.terminator
+    }
+    llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 1d85806bfaf55..6a2a78a4f1f8b 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -479,10 +479,10 @@ llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) {
 
 // -----
 
-llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_thread_limit_too_many_dims(%lb : i32, %ub : i32) {
+  // expected-error@below {{not yet implemented: Unhandled clause thread_limit with more than 3 dimensions in omp.teams operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
-  omp.teams thread_limit(%lb, %ub : i32, i32) {
+  omp.teams thread_limit(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
     omp.terminator
   }
   llvm.return

>From 327f5748a6ef33f383d1d76e743f80a2fd90dcc5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 9 Apr 2026 16:23:40 +0530
Subject: [PATCH 2/3] update

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 88 +++++++++++--------
 .../LLVMIR/openmp-target-launch-device.mlir   | 21 +++++
 .../LLVMIR/openmp-target-launch-host.mlir     | 27 +++++-
 3 files changed, 94 insertions(+), 42 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index ad7705fc014a5..20b757326ac55 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6726,26 +6726,27 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       result = 0;
   };
 
-  // Extract multi-dimensional 'thread_limit' clause from 'target' and 'teams'.
-  llvm::SmallVector<int32_t, 3> targetThreadLimitVals(3, -1);
-  llvm::SmallVector<int32_t, 3> teamsThreadLimitVals(3, -1);
-  for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
-    if (i < 3)
-      setMaxValueFromClause(limitVar, targetThreadLimitVals[i]);
-  }
-  for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
-    if (i < 3)
-      setMaxValueFromClause(limitVar, teamsThreadLimitVals[i]);
-  }
-
-  // Extract multi-dimensional 'num_threads' clause from 'parallel' or set to 1
-  // if it's SIMD.
-  llvm::SmallVector<int32_t, 3> maxThreadsVals(3, -1);
+  // Extract 'thread_limit' clause from 'target' and 'teams'. The number of
+  // dimensions is determined by the clauses present (the >3 dims check in
+  // checkImplementationStatus guards against unsupported counts).
+  size_t numTargetDims = targetOp.getThreadLimitVars().size();
+  size_t numTeamsDims = threadLimitVars.size();
+  size_t numParallelDims = numThreadsVars.size();
+  size_t numDims =
+      std::max({numTargetDims, numTeamsDims, numParallelDims, size_t(1)});
+
+  llvm::SmallVector<int32_t, 3> targetThreadLimitVals(numDims, -1);
+  llvm::SmallVector<int32_t, 3> teamsThreadLimitVals(numDims, -1);
+  for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars()))
+    setMaxValueFromClause(limitVar, targetThreadLimitVals[i]);
+  for (auto [i, limitVar] : llvm::enumerate(threadLimitVars))
+    setMaxValueFromClause(limitVar, teamsThreadLimitVals[i]);
+
+  // Extract 'num_threads' clause from 'parallel' or set to 1 if it's SIMD.
+  llvm::SmallVector<int32_t, 3> maxThreadsVals(numDims, -1);
   if (castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
-    for (auto [i, threadsVar] : llvm::enumerate(numThreadsVars)) {
-      if (i < 3)
-        setMaxValueFromClause(threadsVar, maxThreadsVals[i]);
-    }
+    for (auto [i, threadsVar] : llvm::enumerate(numThreadsVars))
+      setMaxValueFromClause(threadsVar, maxThreadsVals[i]);
   } else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                 /*immediateParent=*/true)) {
     maxThreadsVals[0] = 1;
@@ -6754,8 +6755,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
   // For max values, < 0 means unset, == 0 means set but unknown. Select the
   // minimum value between 'num_threads' and 'thread_limit' clauses that were
   // set, for each dimension.
-  llvm::SmallVector<int32_t, 3> combinedMaxThreadsVals(3, -1);
-  for (size_t i = 0; i < 3; ++i) {
+  llvm::SmallVector<int32_t, 3> combinedMaxThreadsVals(numDims, -1);
+  for (size_t i = 0; i < numDims; ++i) {
     int32_t combined = targetThreadLimitVals[i];
     if (combined < 0 ||
         (teamsThreadLimitVals[i] >= 0 && teamsThreadLimitVals[i] < combined))
@@ -6816,21 +6817,22 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
   Value numTeamsLower, numTeamsUpper;
-  llvm::SmallVector<Value> numThreadsVars, threadLimitVars;
+  llvm::SmallVector<Value> numThreadsVars, teamsThreadLimitVars;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);
   extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, numTeamsUpper,
-                         threadLimitVars, &lowerBounds, &upperBounds, &steps);
+                         teamsThreadLimitVars, &lowerBounds, &upperBounds,
+                         &steps);
 
   // TODO: Handle constant 'if' clauses.
-  // Resize to 3 dimensions to match TargetKernelDefaultAttrs
-  attrs.TargetThreadLimit.resize(3);
   if (!targetOp.getThreadLimitVars().empty()) {
-    for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
-      if (limitVar) {
-        attrs.TargetThreadLimit[i] = moduleTranslation.lookupValue(limitVar);
-      }
-    }
+    attrs.TargetThreadLimit.clear();
+    llvm::transform(targetOp.getThreadLimitVars(),
+                    std::back_inserter(attrs.TargetThreadLimit),
+                    [&](Value limitVar) -> llvm::Value * {
+                      return limitVar ? moduleTranslation.lookupValue(limitVar)
+                                      : nullptr;
+                    });
   }
 
   // The __kmpc_push_num_teams_51 function expects int32 as the arguments.  So,
@@ -6844,16 +6846,26 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
 
-  attrs.TeamsThreadLimit.resize(3);
-  if (!threadLimitVars.empty()) {
-    for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
-      if (limitVar) {
-        attrs.TeamsThreadLimit[i] = builder.CreateSExtOrTrunc(
-            moduleTranslation.lookupValue(limitVar), builder.getInt32Ty());
-      }
-    }
+  if (!teamsThreadLimitVars.empty()) {
+    attrs.TeamsThreadLimit.clear();
+    llvm::transform(teamsThreadLimitVars,
+                    std::back_inserter(attrs.TeamsThreadLimit),
+                    [&](Value limitVar) -> llvm::Value * {
+                      return limitVar
+                                 ? builder.CreateSExtOrTrunc(
+                                       moduleTranslation.lookupValue(limitVar),
+                                       builder.getInt32Ty())
+                                 : nullptr;
+                    });
   }
 
+  // Ensure TargetThreadLimit and TeamsThreadLimit have matching sizes
+  // for zip_equal in OMPIRBuilder.
+  size_t maxDims =
+      std::max(attrs.TargetThreadLimit.size(), attrs.TeamsThreadLimit.size());
+  attrs.TargetThreadLimit.resize(maxDims);
+  attrs.TeamsThreadLimit.resize(maxDims);
+
   // Handle multi-dimensional num_threads (only first value for now)
   if (!numThreadsVars.empty())
     attrs.MaxThreads = moduleTranslation.lookupValue(numThreadsVars[0]);
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
index e27f7fe4b2e7e..8fbe6dbb78e9b 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
@@ -12,6 +12,13 @@
 // CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE2:1]], i32 [[MIN_THREADS2:1]], i32 [[MAX_THREADS2:30]], i32 [[MIN_TEAMS2:40]], i32 [[MAX_TEAMS2:40]], i32 0, i32 0 },
 // CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
+// Multi-dim thread_limit: min(target=20, teams_x=10) = 10.
+// CHECK:      @[[EXEC_MODE3:.*]] = weak protected constant i8 1
+// CHECK:      @llvm.compiler.used{{.*}} = appending global [1 x ptr] [ptr @[[EXEC_MODE3]]], section "llvm.metadata"
+// CHECK:      @[[KERNEL3_ENV:.*_kernel_environment]] = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE3:1]], i32 [[MIN_THREADS3:1]], i32 [[MAX_THREADS3:10]], i32 0, i32 0, i32 0, i32 0 },
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
   llvm.func @main(%num_teams : !llvm.ptr) {
     // CHECK: define weak_odr protected amdgpu_kernel void @__omp_offloading_{{.*}}_main_l{{[0-9]+}}(ptr %[[NUM_TEAMS_ARG:.*]], ptr %[[KERNEL_ARGS:.*]]) #[[ATTRS1:[0-9]+]]
@@ -37,6 +44,20 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       }
       omp.terminator
     }
+
+    // Multi-dim thread_limit: first dim constant, second dim constant.
+    // MaxThreads uses the first dim combined value: min(20, 10) = 10.
+    // CHECK: define weak_odr protected amdgpu_kernel void @__omp_offloading_{{.*}}_main_l{{[0-9]+}}(ptr %[[KERNEL_ARGS:.*]]) #[[ATTRS1]]
+    // CHECK: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL3_ENV]], ptr %[[KERNEL_ARGS]])
+    %target_threads3 = llvm.mlir.constant(20) : i32
+    omp.target thread_limit(%target_threads3 : i32) {
+      %teams_threads_x = llvm.mlir.constant(10) : i32
+      %teams_threads_y = llvm.mlir.constant(5) : i32
+      omp.teams thread_limit(%teams_threads_x, %teams_threads_y : i32, i32) {
+        omp.terminator
+      }
+      omp.terminator
+    }
     llvm.return
   }
 }
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index a8b8696190604..87968c683f997 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,18 +2,26 @@
 
 // CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
 // CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
 
 // CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
 
 // CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
 
 // CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
 // CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]], ptr null)
 
-// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]], ptr %{{.*}})
+// Multi-dim thread_limit: first dim is constant (10), second dim is runtime variable.
+// The NumThreads [3 x i32] array should have dim0=10, dim1=%thread_limit_y, dim2=0.
+// CHECK: define void @main_multidim_thread_limit(i32 %[[TL_Y:.*]])
+// CHECK: %[[KERNEL_ARGS2:.*]] = alloca %struct.__tgt_kernel_arguments
+// CHECK: %[[NT_ARR:.*]] = insertvalue [3 x i32] [i32 10, i32 0, i32 0], i32 %[[TL_Y]], 1
+// CHECK: %[[NT_GEP:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS2]], i32 0, i32 11
+// CHECK-NEXT: store [3 x i32] %[[NT_ARR]], ptr %[[NT_GEP]], align 4
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 0, i32 10, ptr @.{{.*}}.region_id, ptr %[[KERNEL_ARGS2]])
+// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]])
 // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.*}}, i32 {{.*}}, i32 %[[NUM_TEAMS_OUTLINED]], i32 %[[NUM_TEAMS_OUTLINED]], i32 [[NUM_THREADS]])
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
   llvm.func @main(%num_teams : i32) {
@@ -28,4 +36,15 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
     }
     llvm.return
   }
+
+  llvm.func @main_multidim_thread_limit(%thread_limit_y : i32) {
+    %teams_threads_x = llvm.mlir.constant(10) : i32
+    omp.target host_eval(%teams_threads_x -> %arg_tlx, %thread_limit_y -> %arg_tly : i32, i32) {
+      omp.teams thread_limit(%arg_tlx, %arg_tly : i32, i32) {
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
 }

>From d810b8cd1443b67557abd4c53d61026cd454bce0 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 16 Apr 2026 15:57:40 +0530
Subject: [PATCH 3/3] update 2

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++++++++--------
 .../LLVMIR/openmp-target-launch-device.mlir      |  5 ++---
 .../Target/LLVMIR/openmp-target-launch-host.mlir |  2 +-
 mlir/test/Target/LLVMIR/openmp-teams.mlir        |  2 --
 4 files changed, 11 insertions(+), 14 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 20b757326ac55..a713df53bff19 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6524,13 +6524,13 @@ extractHostEvalClauses(omp::TargetOp targetOp,
     for (Operation *user : blockArg.getUsers()) {
       llvm::TypeSwitch<Operation *>(user)
           .Case([&](omp::TeamsOp teamsOp) {
-            if (teamsOp.getNumTeamsLower() == blockArg)
+            if (teamsOp.getNumTeamsLower() == blockArg) {
               numTeamsLower = hostEvalVar;
-            else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
-                                        blockArg))
+            } else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
+                                          blockArg)) {
               numTeamsUpper = hostEvalVar;
-            else if (llvm::is_contained(teamsOp.getThreadLimitVars(),
-                                        blockArg)) {
+            } else if (llvm::is_contained(teamsOp.getThreadLimitVars(),
+                                          blockArg)) {
               for (auto [i, limitVar] :
                    llvm::enumerate(teamsOp.getThreadLimitVars())) {
                 if (limitVar == blockArg) {
@@ -6540,8 +6540,9 @@ extractHostEvalClauses(omp::TargetOp targetOp,
                   break;
                 }
               }
-            } else
+            } else {
               llvm_unreachable("unsupported host_eval use");
+            }
           })
           .Case([&](omp::ParallelOp parallelOp) {
             if (llvm::is_contained(parallelOp.getNumThreadsVars(), blockArg)) {
@@ -6859,8 +6860,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                     });
   }
 
-  // Ensure TargetThreadLimit and TeamsThreadLimit have matching sizes
-  // for zip_equal in OMPIRBuilder.
+  // Ensure TargetThreadLimit and TeamsThreadLimit have matching sizes.
   size_t maxDims =
       std::max(attrs.TargetThreadLimit.size(), attrs.TeamsThreadLimit.size());
   attrs.TargetThreadLimit.resize(maxDims);
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
index 8fbe6dbb78e9b..3e55f8a546d20 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
@@ -12,7 +12,8 @@
 // CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE2:1]], i32 [[MIN_THREADS2:1]], i32 [[MAX_THREADS2:30]], i32 [[MIN_TEAMS2:40]], i32 [[MAX_TEAMS2:40]], i32 0, i32 0 },
 // CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
-// Multi-dim thread_limit: min(target=20, teams_x=10) = 10.
+// Multi-dim thread_limit: first dim constant (10), second dim constant (5).
+// MaxThreads uses the first dim combined value: min(target=20, teams_x=10) = 10.
 // CHECK:      @[[EXEC_MODE3:.*]] = weak protected constant i8 1
 // CHECK:      @llvm.compiler.used{{.*}} = appending global [1 x ptr] [ptr @[[EXEC_MODE3]]], section "llvm.metadata"
 // CHECK:      @[[KERNEL3_ENV:.*_kernel_environment]] = weak_odr protected constant %struct.KernelEnvironmentTy {
@@ -45,8 +46,6 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       omp.terminator
     }
 
-    // Multi-dim thread_limit: first dim constant, second dim constant.
-    // MaxThreads uses the first dim combined value: min(20, 10) = 10.
     // CHECK: define weak_odr protected amdgpu_kernel void @__omp_offloading_{{.*}}_main_l{{[0-9]+}}(ptr %[[KERNEL_ARGS:.*]]) #[[ATTRS1]]
     // CHECK: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL3_ENV]], ptr %[[KERNEL_ARGS]])
     %target_threads3 = llvm.mlir.constant(20) : i32
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index 87968c683f997..4096f8e25182c 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -21,7 +21,7 @@
 // CHECK: %[[NT_GEP:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS2]], i32 0, i32 11
 // CHECK-NEXT: store [3 x i32] %[[NT_ARR]], ptr %[[NT_GEP]], align 4
 // CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 0, i32 10, ptr @.{{.*}}.region_id, ptr %[[KERNEL_ARGS2]])
-// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]])
+// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]], ptr %{{.*}})
 // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.*}}, i32 {{.*}}, i32 %[[NUM_TEAMS_OUTLINED]], i32 %[[NUM_TEAMS_OUTLINED]], i32 [[NUM_THREADS]])
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
   llvm.func @main(%num_teams : i32) {
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index adca15d1c5fcc..126d3e652a6e1 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -319,7 +319,6 @@ llvm.func @duringTeams()
 // CHECK-LABEL: @omp_teams_thread_limit_2d
 // CHECK-SAME: (i32 [[LIMIT_X:.+]], i32 [[LIMIT_Y:.+]])
 llvm.func @omp_teams_thread_limit_2d(%limitX: i32, %limitY: i32) {
-    // Multi-dimensional thread_limit: all dimensions are passed
     // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
     // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
     // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
@@ -337,7 +336,6 @@ llvm.func @duringTeams()
 // CHECK-LABEL: @omp_teams_thread_limit_3d
 // CHECK-SAME: (i32 [[LIMIT_X:.+]], i64 [[LIMIT_Y:.+]], i16 [[LIMIT_Z:.+]])
 llvm.func @omp_teams_thread_limit_3d(%limitX: i32, %limitY: i64, %limitZ: i16) {
-    // Multi-dimensional thread_limit with mixed types: all dimensions are passed
     // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
     // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
     // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])



More information about the llvm-branch-commits mailing list