[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_threads mlir->llvm lowering (PR #179420)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Apr 16 03:33:08 PDT 2026


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/179420

>From 94a5614319bd9435d3e675d0d8e5113f2611cc84 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 3 Feb 2026 15:29:10 +0530
Subject: [PATCH 1/2] [OpenMP][MLIR] Add num_threads mlir->llvm lowering

---
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 66 +++++++++++--------
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  6 +-
 2 files changed, 43 insertions(+), 29 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2e15f4de4545d..3bb0a181529de 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -378,8 +378,9 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("num_teams with multi-dimensional values");
   };
   auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
-    if (op.hasNumThreadsMultiDim())
-      result = todo("num_threads with multi-dimensional values");
+    // Check that we don't exceed the maximum supported dimensions (3)
+    if (op.getNumThreadsDimsCount() > 3)
+      result = todo("num_threads with more than 3 dimensions");
   };
 
   auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
@@ -6501,13 +6502,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
 ///
 /// Loop bounds and steps are only optionally populated, if output vectors are
 /// provided.
-static void
-extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
-                       Value &numTeamsLower, Value &numTeamsUpper,
-                       Value &threadLimit,
-                       llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
-                       llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
-                       llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void extractHostEvalClauses(
+    omp::TargetOp targetOp, llvm::SmallVectorImpl<Value> &numThreadsVars,
+    Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit,
+    llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+    llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+    llvm::SmallVectorImpl<Value> *steps = nullptr) {
   auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
   for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
                                    blockArgIface.getHostEvalBlockArgs())) {
@@ -6528,10 +6528,17 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
-            if (!parallelOp.getNumThreadsVars().empty() &&
-                parallelOp.getNumThreads(0) == blockArg)
-              numThreads = hostEvalVar;
-            else
+            if (llvm::is_contained(parallelOp.getNumThreadsVars(), blockArg)) {
+              for (auto [i, threadsVar] :
+                   llvm::enumerate(parallelOp.getNumThreadsVars())) {
+                if (threadsVar == blockArg) {
+                  if (numThreadsVars.size() <= i)
+                    numThreadsVars.resize(i + 1);
+                  numThreadsVars[i] = hostEvalVar;
+                  break;
+                }
+              }
+            } else
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::LoopNestOp loopOp) {
@@ -6639,10 +6646,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
                        bool isTargetDevice, bool isGPU) {
   // TODO: Handle constant 'if' clauses.
 
-  Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+  Value numTeamsLower, numTeamsUpper, threadLimit;
+  llvm::SmallVector<Value> numThreadsVars;
   if (!isTargetDevice) {
-    extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
-                           threadLimit);
+    extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower,
+                           numTeamsUpper, threadLimit);
   } else {
     // In the target device, values for these clauses are not passed as
     // host_eval, but instead evaluated prior to entry to the region. This
@@ -6657,8 +6665,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
-      if (!parallelOp.getNumThreadsVars().empty())
-        numThreads = parallelOp.getNumThreads(0);
+      // Handle multi-dimensional num_threads
+      numThreadsVars.reserve(parallelOp.getNumThreadsVars().size());
+      for (auto threadsVar : parallelOp.getNumThreadsVars())
+        numThreadsVars.push_back(threadsVar);
     }
   }
 
@@ -6706,10 +6716,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
   int32_t maxThreadsVal = -1;
-  if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
-    setMaxValueFromClause(numThreads, maxThreadsVal);
-  else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
-                                              /*immediateParent=*/true))
+  if (castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+    // For multi-dimensional num_threads, only use the first dimension for now
+    if (!numThreadsVars.empty())
+      setMaxValueFromClause(numThreadsVars[0], maxThreadsVal);
+  } else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
+                                                /*immediateParent=*/true))
     maxThreadsVal = 1;
 
   // For max values, < 0 means unset, == 0 means set but unknown. Select the
@@ -6773,10 +6785,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
   omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
   unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
 
-  Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  Value numTeamsLower, numTeamsUpper, teamsThreadLimit;
+  llvm::SmallVector<Value> numThreadsVars;
   llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
       steps(numLoops);
-  extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+  extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, numTeamsUpper,
                          teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
@@ -6801,8 +6814,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
     attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
         moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
 
-  if (numThreads)
-    attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
+  // Handle multi-dimensional num_threads (only first value for now)
+  if (!numThreadsVars.empty())
+    attrs.MaxThreads = moduleTranslation.lookupValue(numThreadsVars[0]);
 
   if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
                               omp::TargetRegionFlags::trip_count)) {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index e0872226531e6..22ca025ac85d6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -457,10 +457,10 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
-llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
-  // expected-error at below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) {
+  // expected-error at below {{not yet implemented: Unhandled clause num_threads with more than 3 dimensions in omp.parallel operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
-  omp.parallel num_threads(%lb, %ub : i32, i32) {
+  omp.parallel num_threads(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
     omp.terminator
   }
   llvm.return

>From bbb3f8c7a20cca89801cf00d0250e0c21681aa02 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 10 Apr 2026 11:31:59 +0530
Subject: [PATCH 2/2] Reject multi-dimensional num_threads outside target regions and add todo test

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp    | 13 ++++++++++---
 mlir/test/Target/LLVMIR/openmp-todo.mlir            | 11 +++++++++++
 2 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3bb0a181529de..64c7e5700c771 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -378,9 +378,15 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       result = todo("num_teams with multi-dimensional values");
   };
   auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
-    // Check that we don't exceed the maximum supported dimensions (3)
-    if (op.getNumThreadsDimsCount() > 3)
+    if (op.getNumThreadsDimsCount() > 3) {
       result = todo("num_threads with more than 3 dimensions");
+      return;
+    }
+
+    if (op.hasNumThreadsMultiDim() &&
+        !op->template getParentOfType<omp::TargetOp>())
+      result = todo(
+          "num_threads with multi-dimensional values outside target region");
   };
 
   auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
@@ -6538,8 +6544,9 @@ static void extractHostEvalClauses(
                   break;
                 }
               }
-            } else
+            } else {
               llvm_unreachable("unsupported host_eval use");
+            }
           })
           .Case([&](omp::LoopNestOp loopOp) {
             auto processBounds =
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 22ca025ac85d6..1d85806bfaf55 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -457,6 +457,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @parallel_num_threads_multi_dim_standalone(%lb : i32, %ub : i32) {
+  // expected-error at below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values outside target region in omp.parallel operation}}
+  // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
+  omp.parallel num_threads(%lb, %ub : i32, i32) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) {
   // expected-error at below {{not yet implemented: Unhandled clause num_threads with more than 3 dimensions in omp.parallel operation}}
   // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}



More information about the Mlir-commits mailing list