[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_threads mlir->llvm lowering (PR #179420)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Feb 16 00:52:19 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
@llvm/pr-subscribers-mlir-llvm
@llvm/pr-subscribers-mlir
Author: Chaitanya (skc7)
<details>
<summary>Changes</summary>
The `num_threads` clause in the OpenMP MLIR dialect now supports multi-dimensional values when lowering to LLVM IR.
Changes:
- Update checkNumThreads validation to allow multi-dimensional num_threads up to 3 dimensions inside target regions
- Update extractHostEvalClauses to extract all num_threads dimensions from parallelOp.getNumThreadsVars() into a vector
- Update initTargetDefaultAttrs to handle multi-dimensional `num_threads` extraction. Currently, only the first dimension is used for `maxThreadsVal`, since the runtime supports only one dimension.
- Update initTargetRuntimeAttrs to extract all dimensions into the `numThreadsVars` vector. Only the first dimension is passed to `attrs.MaxThreads`.
---
Full diff: https://github.com/llvm/llvm-project/pull/179420.diff
2 Files Affected:
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+40-26)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+3-3)
``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 022322502a755..557a0e46ec4ba 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -381,8 +381,9 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("num_teams with multi-dimensional values");
};
auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
- if (op.hasNumThreadsMultiDim())
- result = todo("num_threads with multi-dimensional values");
+ // Check that we don't exceed the maximum supported dimensions (3)
+ if (op.getNumThreadsDimsCount() > 3)
+ result = todo("num_threads with more than 3 dimensions");
};
auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
@@ -6037,13 +6038,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
-static void
-extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
- Value &numTeamsLower, Value &numTeamsUpper,
- Value &threadLimit,
- llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
- llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
- llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void extractHostEvalClauses(
+ omp::TargetOp targetOp, llvm::SmallVectorImpl<Value> &numThreadsVars,
+ Value &numTeamsLower, Value &numTeamsUpper, Value &threadLimit,
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *steps = nullptr) {
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
blockArgIface.getHostEvalBlockArgs())) {
@@ -6064,10 +6064,17 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
- if (!parallelOp.getNumThreadsVars().empty() &&
- parallelOp.getNumThreads(0) == blockArg)
- numThreads = hostEvalVar;
- else
+ if (llvm::is_contained(parallelOp.getNumThreadsVars(), blockArg)) {
+ for (auto [i, threadsVar] :
+ llvm::enumerate(parallelOp.getNumThreadsVars())) {
+ if (threadsVar == blockArg) {
+ if (numThreadsVars.size() <= i)
+ numThreadsVars.resize(i + 1);
+ numThreadsVars[i] = hostEvalVar;
+ break;
+ }
+ }
+ } else
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::LoopNestOp loopOp) {
@@ -6168,10 +6175,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
bool isTargetDevice, bool isGPU) {
// TODO: Handle constant 'if' clauses.
- Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ Value numTeamsLower, numTeamsUpper, threadLimit;
+ llvm::SmallVector<Value> numThreadsVars;
if (!isTargetDevice) {
- extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- threadLimit);
+ extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower,
+ numTeamsUpper, threadLimit);
} else {
// In the target device, values for these clauses are not passed as
// host_eval, but instead evaluated prior to entry to the region. This
@@ -6186,8 +6194,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
- if (!parallelOp.getNumThreadsVars().empty())
- numThreads = parallelOp.getNumThreads(0);
+ // Handle multi-dimensional num_threads
+ numThreadsVars.reserve(parallelOp.getNumThreadsVars().size());
+ for (auto threadsVar : parallelOp.getNumThreadsVars())
+ numThreadsVars.push_back(threadsVar);
}
}
@@ -6235,10 +6245,12 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
int32_t maxThreadsVal = -1;
- if (castOrGetParentOfType<omp::ParallelOp>(capturedOp))
- setMaxValueFromClause(numThreads, maxThreadsVal);
- else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
- /*immediateParent=*/true))
+ if (castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // For multi-dimensional num_threads, only use the first dimension for now
+ if (!numThreadsVars.empty())
+ setMaxValueFromClause(numThreadsVars[0], maxThreadsVal);
+ } else if (castOrGetParentOfType<omp::SimdOp>(capturedOp,
+ /*immediateParent=*/true))
maxThreadsVal = 1;
// For max values, < 0 means unset, == 0 means set but unknown. Select the
@@ -6302,10 +6314,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
- Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ Value numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ llvm::SmallVector<Value> numThreadsVars;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
- extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+ extractHostEvalClauses(targetOp, numThreadsVars, numTeamsLower, numTeamsUpper,
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
@@ -6330,8 +6343,9 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
- if (numThreads)
- attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
+ // Handle multi-dimensional num_threads (only first value for now)
+ if (!numThreadsVars.empty())
+ attrs.MaxThreads = moduleTranslation.lookupValue(numThreadsVars[0]);
if (omp::bitEnumContainsAny(targetOp.getKernelExecFlags(capturedOp),
omp::TargetRegionFlags::trip_count)) {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 36338e5cb1bed..fe463da873461 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -443,10 +443,10 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
-llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
- // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+llvm.func @parallel_num_threads_too_many_dims(%lb : i32, %ub : i32) {
+ // expected-error@below {{not yet implemented: Unhandled clause num_threads with more than 3 dimensions in omp.parallel operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.parallel}}
- omp.parallel num_threads(%lb, %ub : i32, i32) {
+ omp.parallel num_threads(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
omp.terminator
}
llvm.return
``````````
</details>
https://github.com/llvm/llvm-project/pull/179420
More information about the Mlir-commits
mailing list