[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams mlir to llvm lowering (PR #179418)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Feb 3 01:19:55 PST 2026


https://github.com/skc7 created https://github.com/llvm/llvm-project/pull/179418

None

>From 189c8e4851f1456723a8873381e09b62b8d17f91 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 3 Feb 2026 14:47:58 +0530
Subject: [PATCH] [OpenMP][MLIR] Add num_teams mlir to llvm lowering

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 114 ++++++++++++------
 .../LLVMIR/openmp-target-launch-host.mlir     |   6 +-
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  20 ++-
 3 files changed, 101 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 022322502a755..2d2d4e1abf33c 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -377,8 +377,28 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("task_reduction");
   };
   auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
-    if (op.hasNumTeamsMultiDim())
-      result = todo("num_teams with multi-dimensional values");
+    if (op.getNumTeamsDimsCount() > 3) {
+      result = todo("num_teams with more than 3 dimensions");
+      return;
+    }
+
+    // Multi-dimensional num_teams is only fully supported within target
+    // regions.
+    if (op.hasNumTeamsMultiDim()) {
+      Operation *parent = op.getOperation()->getParentOp();
+      bool insideTarget = false;
+      while (parent) {
+        if (isa<omp::TargetOp>(parent)) {
+          insideTarget = true;
+          break;
+        }
+        parent = parent->getParentOp();
+      }
+
+      if (!insideTarget)
+        result = todo(
+            "num_teams with multi-dimensional values outside target region");
+    }
   };
   auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumThreadsMultiDim())
@@ -6037,13 +6057,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void
-extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                       Value &numTeamsLower, Value &numTeamsUpper,
-                       Value &threadLimit,
-                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
-                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
-                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void extractHostEvalClauses(
+    omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower,
+    llvm::SmallVectorImpl<Value> &numTeamsUpperVars, Value &threadLimit,
+    llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+    llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+    llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -6055,10 +6074,19 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
             if (teamsOp.getNumTeamsLower() == blockArg)
               numTeamsLower = hostEvalVar;
             else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
-                                        blockArg))
-              numTeamsUpper = hostEvalVar;
-            else if (!teamsOp.getThreadLimitVars().empty() &&
-                     teamsOp.getThreadLimit(0) == blockArg)
+                                        blockArg)) {
+              // Find which dimension this blockArg corresponds to.
+              for (auto [i, upperVar] :
+                   llvm::enumerate(teamsOp.getNumTeamsUpperVars())) {
+                if (upperVar == blockArg) {
+                  if (numTeamsUpperVars.size() <= i)
+                    numTeamsUpperVars.resize(i + 1);
+                  numTeamsUpperVars[i] = hostEvalVar;
+                  break;
+                }
+              }
+            } else if (!teamsOp.getThreadLimitVars().empty() &&
+                       teamsOp.getThreadLimit(0) == blockArg)
               threadLimit = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6168,19 +6196,22 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                        bool isTargetDevice, bool isGPU) {
   // TODO: Handle constant 'if' clauses.
 
-  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+  Value numThreads, numTeamsLower, threadLimit;
+  llvm::SmallVector<Value> numTeamsUpperVars;
   if (!isTargetDevice) {
-    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                           threadLimit);
+    extractHostEvalClauses(targetOp, numThreads, numTeamsLower,
+                           numTeamsUpperVars, threadLimit);
   } else {
     // In the target device, values for these clauses are not passed as
     // host_eval, but instead evaluated prior to entry to the region. This
     // ensures values are mapped and available inside of the target region.
     if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
       numTeamsLower = teamsOp.getNumTeamsLower();
-      // Handle num_teams upper bounds (only first value for now)
-      if (!teamsOp.getNumTeamsUpperVars().empty())
-        numTeamsUpper = teamsOp.getNumTeams(0);
+      // Handle all num_teams upper bound dimensions.
+      numTeamsUpperVars.reserve(teamsOp.getNumTeamsUpperVars().size());
+      for (auto upperVar : teamsOp.getNumTeamsUpperVars())
+        numTeamsUpperVars.push_back(upperVar);
+      // Handle thread_limit (only the first value for now).
       if (!teamsOp.getThreadLimitVars().empty())
         threadLimit = teamsOp.getThreadLimit(0);
     }
@@ -6193,23 +6224,30 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Handle clauses impacting the number of teams.
 
-  int32_t minTeamsVal = 1, maxTeamsVal = -1;
+  int32_t minTeamsVal = 1;
+  llvm::SmallVector<int32_t, 3> maxTeamsVals(3, -1);
   if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
-    // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
+    // TODO: Use `numTeamsLower` to initialize `minTeamsVal`. For now,
     // match clang and set min and max to the same value.
-    if (numTeamsUpper) {
-      if (auto val = extractConstInteger(numTeamsUpper))
-        minTeamsVal = maxTeamsVal = *val;
-    } else {
-      minTeamsVal = maxTeamsVal = 0;
+    if (!numTeamsUpperVars.empty()) {
+      // Record the constant upper bound of each num_teams dimension.
+      for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+        if (upperVar) {
+          if (auto val = extractConstInteger(upperVar)) {
+            maxTeamsVals[i] = *val;
+            if (i == 0)
+              minTeamsVal = *val;
+          }
+        }
+      }
     }
   } else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
                                                     /*immediateParent=*/true) ||
              castOrGetParentOfType<omp::SimdOp>(capturedOp,
                                                 /*immediateParent=*/true)) {
-    minTeamsVal = maxTeamsVal = 1;
+    minTeamsVal = maxTeamsVals[0] = 1;
   } else {
-    minTeamsVal = maxTeamsVal = -1;
+    minTeamsVal = maxTeamsVals[0] = -1;
   }
 
   // Handle clauses impacting the number of threads.
@@ -6278,7 +6316,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
 
   attrs.MinTeams = minTeamsVal;
-  attrs.MaxTeams.front() = maxTeamsVal;
+  attrs.MaxTeams = maxTeamsVals;
   attrs.MinThreads = 1;
   attrs.MaxThreads.front() = combinedMaxThreadsVal;
   attrs.ReductionDataSize = reductionDataSize;
@@ -6302,10 +6340,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
-  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  Value numThreads, numTeamsLower, teamsThreadLimit;
+  llvm::SmallVector<Value> numTeamsUpperVars;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);
-  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpperVars,
                          teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
@@ -6322,9 +6361,16 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.MinTeams = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
 
-  if (numTeamsUpper)
-    attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
-        moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
+  // Handle multi-dimensional num_teams upper bounds.
+  attrs.MaxTeams.resize(3);
+  if (!numTeamsUpperVars.empty()) {
+    for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+      if (upperVar) {
+        attrs.MaxTeams[i] = builder.CreateSExtOrTrunc(
+            moduleTranslation.lookupValue(upperVar), builder.getInt32Ty());
+      }
+    }
+  }
 
   if (teamsThreadLimit)
     attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index abc67017b620d..9cc9a5025b613 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
 
 // CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
 // CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
 
 // CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
 
 // CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
 
 // CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
 // CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]])
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 36338e5cb1bed..f04c27139df93 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -432,8 +432,8 @@ llvm.func @teams_private(%x : !llvm.ptr) {
 
 // -----
 
-llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
-  // expected-error at below {{not yet implemented: Unhandled clause num_teams with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_num_teams_multi_dim_standalone(%lb : i32, %ub : i32) {
+  // expected-error at below {{not yet implemented: Unhandled clause num_teams with multi-dimensional values outside target region in omp.teams operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.teams}}
   omp.teams num_teams(to %ub, %ub, %ub : i32, i32, i32) {
     omp.terminator
@@ -443,6 +443,22 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @teams_num_teams_too_many_dims() {
+  // expected-error at below {{LLVM Translation failed for operation: omp.target}}
+  omp.target {
+    %c100 = llvm.mlir.constant(100 : i32) : i32
+    // expected-error at below {{not yet implemented: Unhandled clause num_teams with more than 3 dimensions in omp.teams operation}}
+    // expected-error at below {{LLVM Translation failed for operation: omp.teams}}
+    omp.teams num_teams(to %c100, %c100, %c100, %c100 : i32, i32, i32, i32) {
+      omp.terminator
+    }
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
   // expected-error at below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}



More information about the Mlir-commits mailing list