[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_threads mlir->llvm lowering (PR #179420)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Thu Apr 16 03:33:08 PDT 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/179420
>From 94a5614319bd9435d3e675d0d8e5113f2611cc84 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 3 Feb 2026 15:29:10 +0530
Subject: [PATCH 1/2] [OpenMP][MLIR] Add num_threads mlir->llvm lowering
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 66 +++++++++++--------
mlir/test/Target/LLVMIR/openmp-todo.mlir | 6 +-
2 files changed, 43 insertions(+), 29 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2e15f4de4545d..3bb0a181529de 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -378,8 +378,9 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("num_teams with multi-dimensional values");
};
auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
- if (op.hasNumThreadsMultiDim())
- result = todo("num_threads with multi-dimensional values");
+ // Check that we don't exceed the maximum supported dimensions (3)
+ if (op.getNumThreadsDimsCount() > 3)
+ result = todo("num_threads with more than 3 dimensions");
};
auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
@@ -6501,13 +6502,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
-static void
-extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
- Value &numTeamsLower, Value &numTeamsUpper,
- Value &threadLimit,
- llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
- llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
- llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void extractHostEvalClauses(
+ omp::TargetOp targetOp, llvm::SmallVectorImpl<Value> &numThreadsVars,
+ Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit,
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *steps = nullptr) {
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
blockArgIface.getHostEvalBlockArgs())) {
@@ -6528,10 +6528,17 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
- if (!parallelOp.getNumThreadsVars().empty() &&
- parallelOp.getNumThreads(0) == blockArg)
- numThreads = hostEvalVar;
- else
+ if (llvm::is_contained(parallelOp.getNumThreadsVars(), blockArg)) {
+ for (auto [i, threadsVar] :
+ llvm::enumerate(parallelOp.getNumThreadsVars())) {
+ if (threadsVar == blockArg) {
+ if (numThreadsVars.size() <= i)
+ numThreadsVars.resize(i + 1);
+ numThreadsVars[i] = hostEvalVar;
+ break;
+ }
+ }
+ } else
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::LoopNestOp loopOp) {
@@ -6639,10 +6646,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
bool isTargetDevice, bool isGPU) {
// TODO: Handle constant 'if' clauses.
- Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ Value numTeamsLower, numTeamsUpper, threadLimit;
+ llvm::SmallVector<Value> numThreadsVars;
if (!isTargetDevice) {
- extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- threadLimit);
+ extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower,
+ numTeamsUpper, threadLimit);
} else {
// In the target device, values for these clauses are not passed as
// host_eval, but instead evaluated prior to entry to the region. This
@@ -6657,8 +6665,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
- if (!parallelOp.getNumThreadsVars().empty())
- numThreads = parallelOp.getNumThreads(0);
+ // Handle multi-dimensional num_threads
+ numThreadsVars.reserve(parallelOp.getNumThreadsVars().size());
+ for (auto threadsVar : parallelOp.getNumThreadsVars())
+ numThreadsVars.push_back(threadsVar);
}
}
@@ -6706,10 +6716,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
int32_t maxThreadsVal = -1;
- if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
- setMaxValueFromClause(numThreads, maxThreadsVal);
- else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
- /*immediateParent=*/true))
+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // For multi-dimensional num_threads, only use the first dimension for now
+ if (!numThreadsVars.empty())
+ setMaxValueFromClause(numThreadsVars[0], maxThreadsVal);
+ } else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
+ /*immediateParent=*/true))
maxThreadsVal = 1;
// For max values, < 0 means unset, == 0 means set but unknown. Select the
@@ -6773,10 +6785,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
- Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ Value numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ llvm::SmallVector<Value> numThreadsVars;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
- extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+ extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, numTeamsUpper,
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
@@ -6801,8 +6814,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
- if (numThreads)
- attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
+ // Handle multi-dimensional num_threads (only first value for now)
+ if (!numThreadsVars.empty())
+ attrs.MaxThreads = moduleTranslation.lookupValue(numThreadsVars[0]);
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
omp::TargetRegionFlags::trip_count)) {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index e0872226531e6..22ca025ac85d6 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -457,10 +457,10 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
-llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
- // expected-error at below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) {
+ // expected-error at below {{not yet implemented: Unhandled clause num_threads with more than 3 dimensions in omp.parallel operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
- omp.parallel num_threads(%lb, %ub : i32, i32) {
+ omp.parallel num_threads(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
omp.terminator
}
llvm.return
>From bbb3f8c7a20cca89801cf00d0250e0c21681aa02 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 10 Apr 2026 11:31:59 +0530
Subject: [PATCH 2/2] add todo test
---
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 13 ++++++++++---
mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 +++++++++++
2 files changed, 21 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 3bb0a181529de..64c7e5700c771 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -378,9 +378,15 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("num_teams with multi-dimensional values");
};
auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
- // Check that we don't exceed the maximum supported dimensions (3)
- if (op.getNumThreadsDimsCount() > 3)
+ if (op.getNumThreadsDimsCount() > 3) {
result = todo("num_threads with more than 3 dimensions");
+ return;
+ }
+
+ if (op.hasNumThreadsMultiDim() &&
+ !op->template getParentOfType<omp::TargetOp>())
+ result = todo(
+ "num_threads with multi-dimensional values outside target region");
};
auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
@@ -6538,8 +6544,9 @@ static void extractHostEvalClauses(
break;
}
}
- } else
+ } else {
llvm_unreachable("unsupported host_eval use");
+ }
})
.Case([&](omp::LoopNestOp loopOp) {
auto processBounds =
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 22ca025ac85d6..1d85806bfaf55 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -457,6 +457,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
+llvm.func @parallel_num_threads_multi_dim_standalone(%lb : i32, %ub : i32) {
+ // expected-error at below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values outside target region in omp.parallel operation}}
+ // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
+ omp.parallel num_threads(%lb, %ub : i32, i32) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) {
// expected-error at below {{not yet implemented: Unhandled clause num_threads with more than 3 dimensions in omp.parallel operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
More information about the Mlir-commits
mailing list