[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams mlir to llvm lowering (PR #179418)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Apr 15 22:51:37 PDT 2026


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/179418

>From 3a85815c8dcac11bf9ed49a996aa15e3f68437bc Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 3 Feb 2026 14:47:58 +0530
Subject: [PATCH 1/3] [OpenMP][MLIR] Add num_teams mlir to llvm lowering

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 117 +++++++++++++-----
 .../LLVMIR/openmp-target-launch-host.mlir     |   6 +-
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  20 ++-
 3 files changed, 104 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2e15f4de4545d..c225c58add282 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -374,8 +374,28 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("task_reduction");
   };
   auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
-    if (op.hasNumTeamsMultiDim())
-      result = todo("num_teams with multi-dimensional values");
+    if (op.getNumTeamsDimsCount() > 3) {
+      result = todo("num_teams with more than 3 dimensions");
+      return;
+    }
+
+    // Multi-dimensional num_teams is only fully supported within target
+    // regions.
+    if (op.hasNumTeamsMultiDim()) {
+      Operation *parent = op.getOperation()->getParentOp();
+      bool insideTarget = false;
+      while (parent) {
+        if (isa<omp::TargetOp>(parent)) {
+          insideTarget = true;
+          break;
+        }
+        parent = parent->getParentOp();
+      }
+
+      if (!insideTarget)
+        result = todo(
+            "num_teams with multi-dimensional values outside target region");
+    }
   };
   auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumThreadsMultiDim())
@@ -6501,13 +6521,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void
-extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                       Value &numTeamsLower, Value &numTeamsUpper,
-                       Value &threadLimit,
-                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
-                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
-                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void extractHostEvalClauses(
+    omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower,
+    llvm::SmallVectorImpl<Value> &numTeamsUpperVars, Value &threadLimit,
+    llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+    llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+    llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -6519,10 +6538,19 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
             if (teamsOp.getNumTeamsLower() == blockArg)
               numTeamsLower = hostEvalVar;
             else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
-                                        blockArg))
-              numTeamsUpper = hostEvalVar;
-            else if (!teamsOp.getThreadLimitVars().empty() &&
-                     teamsOp.getThreadLimit(0) == blockArg)
+                                        blockArg)) {
+              // Find which dimension this blockArg corresponds to
+              for (auto [i, upperVar] :
+                   llvm::enumerate(teamsOp.getNumTeamsUpperVars())) {
+                if (upperVar == blockArg) {
+                  if (numTeamsUpperVars.size() <= i)
+                    numTeamsUpperVars.resize(i + 1);
+                  numTeamsUpperVars[i] = hostEvalVar;
+                  break;
+                }
+              }
+            } else if (!teamsOp.getThreadLimitVars().empty() &&
+                       teamsOp.getThreadLimit(0) == blockArg)
               threadLimit = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6639,19 +6667,22 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                        bool isTargetDevice, bool isGPU) {
   // TODO: Handle constant 'if' clauses.
 
-  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+  Value numThreads, numTeamsLower, threadLimit;
+  llvm::SmallVector<Value> numTeamsUpperVars;
   if (!isTargetDevice) {
-    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                           threadLimit);
+    extractHostEvalClauses(targetOp, numThreads, numTeamsLower,
+                           numTeamsUpperVars, threadLimit);
   } else {
     // In the target device, values for these clauses are not passed as
     // host_eval, but instead evaluated prior to entry to the region. This
     // ensures values are mapped and available inside of the target region.
     if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
       numTeamsLower = teamsOp.getNumTeamsLower();
-      // Handle num_teams upper bounds (only first value for now)
-      if (!teamsOp.getNumTeamsUpperVars().empty())
-        numTeamsUpper = teamsOp.getNumTeams(0);
+      // Handle all num_teams upper bound dimensions
+      numTeamsUpperVars.reserve(teamsOp.getNumTeamsUpperVars().size());
+      for (auto upperVar : teamsOp.getNumTeamsUpperVars())
+        numTeamsUpperVars.push_back(upperVar);
+      // Handle thread_limit (only first value for now)
       if (!teamsOp.getThreadLimitVars().empty())
         threadLimit = teamsOp.getThreadLimit(0);
     }
@@ -6664,23 +6695,30 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Handle clauses impacting the number of teams.
 
-  int32_t minTeamsVal = 1, maxTeamsVal = -1;
+  int32_t minTeamsVal = 1;
+  llvm::SmallVector<int32_t, 3> maxTeamsVals(3, -1);
   if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
-    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
+    // TODO: Use `numTeamsLower` to initialize `minTeamsVal`. For now,
     // match clang and set min and max to the same value.
-    if (numTeamsUpper) {
-      if (auto val = extractConstInteger(numTeamsUpper))
-        minTeamsVal = maxTeamsVal = *val;
-    } else {
-      minTeamsVal = maxTeamsVal = 0;
+    if (!numTeamsUpperVars.empty()) {
+      // Handle multi-dimensional num_teams
+      for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+        if (upperVar) {
+          if (auto val = extractConstInteger(upperVar)) {
+            maxTeamsVals[i] = *val;
+            if (i == 0)
+              minTeamsVal = *val;
+          }
+        }
+      }
     }
   } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                     /*immediateParent=*/true) ||
              castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                 /*immediateParent=*/true)) {
-    minTeamsVal = maxTeamsVal = 1;
+    minTeamsVal = maxTeamsVals[0] = 1;
   } else {
-    minTeamsVal = maxTeamsVal = -1;
+    minTeamsVal = maxTeamsVals[0] = -1;
   }
 
   // Handle clauses impacting the number of threads.
@@ -6749,7 +6787,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
 
   attrs.MinTeams = minTeamsVal;
-  attrs.MaxTeams.front() = maxTeamsVal;
+  // Always resize to 3 dimensions to match TargetKernelRuntimeAttrs
+  attrs.MaxTeams.resize(3, -1);
+  for (size_t i = 0; i < maxTeamsVals.size() && i < attrs.MaxTeams.size(); ++i)
+    attrs.MaxTeams[i] = maxTeamsVals[i];
   attrs.MinThreads = 1;
   attrs.MaxThreads.front() = combinedMaxThreadsVal;
   attrs.ReductionDataSize = reductionDataSize;
@@ -6773,10 +6814,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
-  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  Value numThreads, numTeamsLower, teamsThreadLimit;
+  llvm::SmallVector<Value> numTeamsUpperVars;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);
-  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpperVars,
                          teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
@@ -6793,9 +6835,16 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.MinTeams = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
 
-  if (numTeamsUpper)
-    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
-        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
+  // Handle multi-dimensional num_teams upper bounds
+  attrs.MaxTeams.resize(3);
+  if (!numTeamsUpperVars.empty()) {
+    for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+      if (upperVar) {
+        attrs.MaxTeams[i] = builder.CreateSExtOrTrunc(
+            moduleTranslation.lookupValue(upperVar), builder.getInt32Ty());
+      }
+    }
+  }
 
   if (teamsThreadLimit)
     attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index deb1e6cef50bd..a8b8696190604 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
 
 // CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
 // CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
 
 // CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
 
 // CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
 
 // CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
 // CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]], ptr null)
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index e0872226531e6..7e9e4ad8d07f6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -446,8 +446,8 @@ llvm.func @teams_private(%x : !llvm.ptr) {
 
 // -----
 
-llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
-  // expected-error@below {{not yet implemented: Unhandled clause num_teams with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_num_teams_multi_dim_standalone(%lb : i32, %ub : i32) {
+  // expected-error@below {{not yet implemented: Unhandled clause num_teams with multi-dimensional values outside target region in omp.teams operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
   omp.teams num_teams(to %ub, %ub, %ub : i32, i32, i32) {
     omp.terminator
@@ -457,6 +457,22 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @teams_num_teams_too_many_dims() {
+  // expected-error@below {{LLVM Translation failed for operation: omp.target}}
+  omp.target {
+    %c100 = llvm.mlir.constant(100 : i32) : i32
+    // expected-error@below {{not yet implemented: Unhandled clause num_teams with more than 3 dimensions in omp.teams operation}}
+    // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
+    omp.teams num_teams(to %c100, %c100, %c100, %c100 : i32, i32, i32, i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
   // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.parallel}}

>From 178ef935b7a0479ff0922028efbe334b3b4382c2 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 9 Apr 2026 17:08:54 +0530
Subject: [PATCH 2/3] update

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 70 ++++++++-----------
 .../LLVMIR/openmp-target-launch-device.mlir   | 19 +++++
 .../LLVMIR/openmp-target-launch-host.mlir     | 28 ++++++--
 3 files changed, 73 insertions(+), 44 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index c225c58add282..4e60e62f7f59f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -379,23 +379,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       return;
     }
 
-    // Multi-dimensional num_teams is only fully supported within target
-    // regions.
-    if (op.hasNumTeamsMultiDim()) {
-      Operation *parent = op.getOperation()->getParentOp();
-      bool insideTarget = false;
-      while (parent) {
-        if (isa<omp::TargetOp>(parent)) {
-          insideTarget = true;
-          break;
-        }
-        parent = parent->getParentOp();
-      }
-
-      if (!insideTarget)
-        result = todo(
-            "num_teams with multi-dimensional values outside target region");
-    }
+    if (op.hasNumTeamsMultiDim() &&
+        !isa_and_present<omp::TargetOp>(op->getParentOp()))
+      result =
+          todo("num_teams with multi-dimensional values outside target region");
   };
   auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumThreadsMultiDim())
@@ -6535,10 +6522,10 @@ static void extractHostEvalClauses(
     for (Operation *user : blockArg.getUsers()) {
       llvm::TypeSwitch<Operation *>(user)
           .Case([&](omp::TeamsOp teamsOp) {
-            if (teamsOp.getNumTeamsLower() == blockArg)
+            if (teamsOp.getNumTeamsLower() == blockArg) {
               numTeamsLower = hostEvalVar;
-            else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
-                                        blockArg)) {
+            } else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
+                                          blockArg)) {
               // Find which dimension this blockArg corresponds to
               for (auto [i, upperVar] :
                    llvm::enumerate(teamsOp.getNumTeamsUpperVars())) {
@@ -6550,10 +6537,11 @@ static void extractHostEvalClauses(
                 }
               }
             } else if (!teamsOp.getThreadLimitVars().empty() &&
-                       teamsOp.getThreadLimit(0) == blockArg)
+                       teamsOp.getThreadLimit(0) == blockArg) {
               threadLimit = hostEvalVar;
-            else
+            } else {
               llvm_unreachable("unsupported host_eval use");
+            }
           })
           .Case([&](omp::ParallelOp parallelOp) {
             if (!parallelOp.getNumThreadsVars().empty() &&
@@ -6696,21 +6684,28 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
   // Handle clauses impacting the number of teams.
 
   int32_t minTeamsVal = 1;
-  llvm::SmallVector<int32_t, 3> maxTeamsVals(3, -1);
+  llvm::SmallVector<int32_t, 3> maxTeamsVals(
+      std::max(numTeamsUpperVars.size(), static_cast<size_t>(1)), -1);
   if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
     // TODO: Use `numTeamsLower` to initialize `minTeamsVal`. For now,
     // match clang and set min and max to the same value.
     if (!numTeamsUpperVars.empty()) {
-      // Handle multi-dimensional num_teams
       for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
         if (upperVar) {
-          if (auto val = extractConstInteger(upperVar)) {
+          if (auto val = extractConstInteger(upperVar))
             maxTeamsVals[i] = *val;
-            if (i == 0)
-              minTeamsVal = *val;
-          }
         }
       }
+      // minTeamsVal is a single scalar and only meaningful for the
+      // unidimensional case. Per the spec, lower-bound may not be
+      // specified when the dims modifier is specified, and when unspecified
+      // it equals the upper bound. In the multidimensional case,
+      // maxTeamsVals should be used as both lower and upper bounds for each
+      // dimension.
+      if (maxTeamsVals[0] >= 0)
+        minTeamsVal = maxTeamsVals[0];
+    } else {
+      minTeamsVal = maxTeamsVals[0] = 0;
     }
   } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                     /*immediateParent=*/true) ||
@@ -6787,10 +6782,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
 
   attrs.MinTeams = minTeamsVal;
-  // Always resize to 3 dimensions to match TargetKernelRuntimeAttrs
-  attrs.MaxTeams.resize(3, -1);
-  for (size_t i = 0; i < maxTeamsVals.size() && i < attrs.MaxTeams.size(); ++i)
-    attrs.MaxTeams[i] = maxTeamsVals[i];
+  attrs.MaxTeams = maxTeamsVals;
   attrs.MinThreads = 1;
   attrs.MaxThreads.front() = combinedMaxThreadsVal;
   attrs.ReductionDataSize = reductionDataSize;
@@ -6835,14 +6827,12 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.MinTeams = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
 
-  // Handle multi-dimensional num_teams upper bounds
-  attrs.MaxTeams.resize(3);
-  if (!numTeamsUpperVars.empty()) {
-    for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
-      if (upperVar) {
-        attrs.MaxTeams[i] = builder.CreateSExtOrTrunc(
-            moduleTranslation.lookupValue(upperVar), builder.getInt32Ty());
-      }
+  attrs.MaxTeams.resize(
+      std::max(numTeamsUpperVars.size(), static_cast<size_t>(1)));
+  for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+    if (upperVar) {
+      attrs.MaxTeams[i] = builder.CreateSExtOrTrunc(
+          moduleTranslation.lookupValue(upperVar), builder.getInt32Ty());
     }
   }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
index e27f7fe4b2e7e..a2156a79702e1 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
@@ -12,6 +12,12 @@
 // CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE2:1]], i32 [[MIN_THREADS2:1]], i32 [[MAX_THREADS2:30]], i32 [[MIN_TEAMS2:40]], i32 [[MAX_TEAMS2:40]], i32 0, i32 0 },
 // CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
+// CHECK:      @[[EXEC_MODE3:.*]] = weak protected constant i8 1
+// CHECK:      @llvm.compiler.used{{.*}} = appending global [1 x ptr] [ptr @[[EXEC_MODE3]]], section "llvm.metadata"
+// CHECK:      @[[KERNEL3_ENV:.*_kernel_environment]] = weak_odr protected constant %struct.KernelEnvironmentTy {
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE3:1]], i32 [[MIN_THREADS3:1]], i32 [[MAX_THREADS3:20]], i32 [[MIN_TEAMS3:10]], i32 [[MAX_TEAMS3:10]], i32 0, i32 0 },
+// CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
+
 module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memory_space", 5 : ui32>>, llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_target_device = true, omp.is_gpu = true} {
   llvm.func @main(%num_teams : !llvm.ptr) {
     // CHECK: define weak_odr protected amdgpu_kernel void @__omp_offloading_{{.*}}_main_l{{[0-9]+}}(ptr %[[NUM_TEAMS_ARG:.*]], ptr %[[KERNEL_ARGS:.*]]) #[[ATTRS1:[0-9]+]]
@@ -37,9 +43,22 @@ module attributes {dlti.dl_spec = #dlti.dl_spec<#dlti.dl_entry<"dlti.alloca_memo
       }
       omp.terminator
     }
+
+    // CHECK: define weak_odr protected amdgpu_kernel void @__omp_offloading_{{.*}}_main_l{{[0-9]+}}(ptr %[[KERNEL_ARGS:.*]]) #[[ATTRS3:[0-9]+]]
+    // CHECK: %{{.*}} = call i32 @__kmpc_target_init(ptr @[[KERNEL3_ENV]], ptr %[[KERNEL_ARGS]])
+    %target_threads3 = llvm.mlir.constant(20) : i32
+    omp.target thread_limit(%target_threads3 : i32) {
+      %nt_x = llvm.mlir.constant(10) : i32
+      %nt_y = llvm.mlir.constant(5) : i32
+      omp.teams num_teams(to %nt_x, %nt_y : i32, i32) {
+        omp.terminator
+      }
+      omp.terminator
+    }
     llvm.return
   }
 }
 
 // CHECK: attributes #[[ATTRS1]] = { "amdgpu-flat-work-group-size"="[[MIN_THREADS1]],[[MAX_THREADS1]]" "omp_target_thread_limit"="[[MAX_THREADS1]]" }
 // CHECK: attributes #[[ATTRS2]] = { "amdgpu-flat-work-group-size"="[[MIN_THREADS2]],[[MAX_THREADS2]]" "amdgpu-max-num-workgroups"="[[MIN_TEAMS2]],1,1" "omp_target_num_teams"="[[MIN_TEAMS2]]" "omp_target_thread_limit"="[[MAX_THREADS2]]" }
+// CHECK: attributes #[[ATTRS3]] = { "amdgpu-flat-work-group-size"="[[MIN_THREADS3]],[[MAX_THREADS3]]" "amdgpu-max-num-workgroups"="[[MIN_TEAMS3]],1,1" "omp_target_num_teams"="[[MIN_TEAMS3]]" "omp_target_thread_limit"="[[MAX_THREADS3]]" }
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index a8b8696190604..076c0a18d05ab 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,18 +2,27 @@
 
 // CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
 // CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
 
 // CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
 
 // CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
 
 // CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
 // CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]], ptr null)
 
-// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]], ptr %{{.*}})
+// Multi-dim num_teams: first dim is constant (20), second dim is runtime variable.
+// The NumTeams [3 x i32] array should have dim0=20, dim1=%num_teams_y, dim2=0.
+// CHECK: define void @main_multidim_num_teams(i32 %[[NT_Y:.*]])
+// CHECK: %[[KERNEL_ARGS2:.*]] = alloca %struct.__tgt_kernel_arguments
+// CHECK: %[[NT_ARR:.*]] = insertvalue [3 x i32] [i32 20, i32 0, i32 0], i32 %[[NT_Y]], 1
+// CHECK: %[[NT_GEP:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS2]], i32 0, i32 10
+// CHECK-NEXT: store [3 x i32] %[[NT_ARR]], ptr %[[NT_GEP]], align 4
+// CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 20, i32 0, ptr @.{{.*}}.region_id, ptr %[[KERNEL_ARGS2]])
+
+// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]])
 // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.*}}, i32 {{.*}}, i32 %[[NUM_TEAMS_OUTLINED]], i32 %[[NUM_TEAMS_OUTLINED]], i32 [[NUM_THREADS]])
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
   llvm.func @main(%num_teams : i32) {
@@ -28,4 +37,15 @@ module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-a
     }
     llvm.return
   }
+
+  llvm.func @main_multidim_num_teams(%num_teams_y : i32) {
+    %num_teams_x = llvm.mlir.constant(20) : i32
+    omp.target host_eval(%num_teams_x -> %arg_ntx, %num_teams_y -> %arg_nty : i32, i32) {
+      omp.teams num_teams(to %arg_ntx, %arg_nty : i32, i32) {
+        omp.terminator
+      }
+      omp.terminator
+    }
+    llvm.return
+  }
 }

>From eea8fb3858522ffefec168ae2ada8c3e624a29e7 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 16 Apr 2026 11:11:01 +0530
Subject: [PATCH 3/3] add else block for maxTeamsVals

---
 .../LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp   | 8 +++++---
 mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir  | 2 +-
 mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir    | 2 +-
 3 files changed, 7 insertions(+), 5 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 4e60e62f7f59f..d45d51b5de642 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6526,7 +6526,7 @@ static void extractHostEvalClauses(
               numTeamsLower = hostEvalVar;
             } else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
                                           blockArg)) {
-              // Find which dimension this blockArg corresponds to
+              // Find which dimension this blockArg corresponds to.
               for (auto [i, upperVar] :
                    llvm::enumerate(teamsOp.getNumTeamsUpperVars())) {
                 if (upperVar == blockArg) {
@@ -6666,11 +6666,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     // ensures values are mapped and available inside of the target region.
     if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
       numTeamsLower = teamsOp.getNumTeamsLower();
-      // Handle all num_teams upper bound dimensions
+      // Handle all num_teams upper bound dimensions.
       numTeamsUpperVars.reserve(teamsOp.getNumTeamsUpperVars().size());
       for (auto upperVar : teamsOp.getNumTeamsUpperVars())
         numTeamsUpperVars.push_back(upperVar);
-      // Handle thread_limit (only first value for now)
+      // Handle thread_limit (only first value for now).
       if (!teamsOp.getThreadLimitVars().empty())
         threadLimit = teamsOp.getThreadLimit(0);
     }
@@ -6694,6 +6694,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
         if (upperVar) {
           if (auto val = extractConstInteger(upperVar))
             maxTeamsVals[i] = *val;
+          else
+            maxTeamsVals[i] = 0;
         }
       }
       // minTeamsVal is a single scalar and only meaningful for the
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
index a2156a79702e1..a6e557cb3dbfa 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-device.mlir
@@ -3,7 +3,7 @@
 // CHECK:      @[[EXEC_MODE1:.*]] = weak protected constant i8 1
 // CHECK:      @llvm.compiler.used{{.*}} = appending global [1 x ptr] [ptr @[[EXEC_MODE1]]], section "llvm.metadata"
 // CHECK:      @[[KERNEL1_ENV:.*_kernel_environment]] = weak_odr protected constant %struct.KernelEnvironmentTy {
-// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE1:1]], i32 [[MIN_THREADS1:1]], i32 [[MAX_THREADS1:10]], i32 [[MIN_TEAMS1:1]], i32 [[MAX_TEAMS1:-1]], i32 0, i32 0 },
+// CHECK-SAME: %struct.ConfigurationEnvironmentTy { i8 1, i8 1, i8 [[EXEC_MODE1:1]], i32 [[MIN_THREADS1:1]], i32 [[MAX_THREADS1:10]], i32 [[MIN_TEAMS1:0]], i32 [[MAX_TEAMS1:0]], i32 0, i32 0 },
 // CHECK-SAME: ptr @{{.*}}, ptr @{{.*}} }
 
 // CHECK:      @[[EXEC_MODE2:.*]] = weak protected constant i8 1
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index 076c0a18d05ab..6467ab0993b1e 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -22,7 +22,7 @@
 // CHECK-NEXT: store [3 x i32] %[[NT_ARR]], ptr %[[NT_GEP]], align 4
 // CHECK: call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 20, i32 0, ptr @.{{.*}}.region_id, ptr %[[KERNEL_ARGS2]])
 
-// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]])
+// CHECK: define internal void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_OUTLINED:.*]], ptr %{{.*}})
 // CHECK: call void @__kmpc_push_num_teams_51(ptr {{.*}}, i32 {{.*}}, i32 %[[NUM_TEAMS_OUTLINED]], i32 %[[NUM_TEAMS_OUTLINED]], i32 [[NUM_THREADS]])
 module attributes {omp.is_target_device = false, omp.target_triples = ["amdgcn-amd-amdhsa"]} {
   llvm.func @main(%num_teams : i32) {



More information about the Mlir-commits mailing list