[Mlir-commits] [mlir] [OpenMP][MLIR] Add thread_limit mlir->llvm lowering (PR #179608)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Feb 10 05:31:22 PST 2026
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-flang-openmp
Author: Chaitanya (skc7)
<details>
<summary>Changes</summary>
`thread_limit` in the OpenMP MLIR dialect now supports multiple dimensions.
This PR updates OpenMPToLLVMIRTranslation with the following changes:
Updates `checkThreadLimit` to allow a multi-dimensional `thread_limit` with up to 3 dimensions.
Updates `initTargetDefaultAttrs` to extract all `thread_limit` dimensions from `teamsOp.getThreadLimitVars()`. Uses the first dimension for compile-time constant extraction.
Updates `initTargetRuntimeAttrs` to extract all dimensions into the `threadLimitVars` vector.
---
Full diff: https://github.com/llvm/llvm-project/pull/179608.diff
4 Files Affected:
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+47-20)
- (modified) mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir (+3-3)
- (modified) mlir/test/Target/LLVMIR/openmp-teams.mlir (+36)
- (modified) mlir/test/Target/LLVMIR/openmp-todo.mlir (+3-3)
``````````diff
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 022322502a755..f364e1e1818c3 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -386,8 +386,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
};
auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
- if (op.hasThreadLimitMultiDim())
- result = todo("thread_limit with multi-dimensional values");
+ if (op.getThreadLimitDimsCount() > 3)
+ result = todo("thread_limit with more than 3 dimensions");
};
LogicalResult result = success();
@@ -6040,7 +6040,7 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
Value &numTeamsLower, Value &numTeamsUpper,
- Value &threadLimit,
+ llvm::SmallVectorImpl<Value> &threadLimitVars,
llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
llvm::SmallVectorImpl<Value> *steps = nullptr) {
@@ -6057,10 +6057,18 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
blockArg))
numTeamsUpper = hostEvalVar;
- else if (!teamsOp.getThreadLimitVars().empty() &&
- teamsOp.getThreadLimit(0) == blockArg)
- threadLimit = hostEvalVar;
- else
+ else if (llvm::is_contained(teamsOp.getThreadLimitVars(),
+ blockArg)) {
+ for (auto [i, limitVar] :
+ llvm::enumerate(teamsOp.getThreadLimitVars())) {
+ if (limitVar == blockArg) {
+ if (threadLimitVars.size() <= i)
+ threadLimitVars.resize(i + 1);
+ threadLimitVars[i] = hostEvalVar;
+ break;
+ }
+ }
+ } else
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
@@ -6168,10 +6176,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
bool isTargetDevice, bool isGPU) {
// TODO: Handle constant 'if' clauses.
- Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ Value numThreads, numTeamsLower, numTeamsUpper;
+ llvm::SmallVector<Value> threadLimitVars;
if (!isTargetDevice) {
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- threadLimit);
+ threadLimitVars);
} else {
// In the target device, values for these clauses are not passed as
// host_eval, but instead evaluated prior to entry to the region. This
@@ -6181,8 +6190,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Handle num_teams upper bounds (only first value for now)
if (!teamsOp.getNumTeamsUpperVars().empty())
numTeamsUpper = teamsOp.getNumTeams(0);
- if (!teamsOp.getThreadLimitVars().empty())
- threadLimit = teamsOp.getThreadLimit(0);
+ threadLimitVars.reserve(teamsOp.getThreadLimitVars().size());
+ for (auto limitVar : teamsOp.getThreadLimitVars())
+ threadLimitVars.push_back(limitVar);
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
@@ -6231,7 +6241,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
if (!targetOp.getThreadLimitVars().empty())
setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
- setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
+ if (!threadLimitVars.empty())
+ setMaxValueFromClause(threadLimitVars[0], teamsThreadLimitVal);
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
int32_t maxThreadsVal = -1;
@@ -6278,8 +6289,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
attrs.MinTeams = minTeamsVal;
+ // Always resize to 3 dimensions to match TargetKernelRuntimeAttrs
+ attrs.MaxTeams.resize(3, -1);
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
+ attrs.MaxThreads.resize(3, -1);
attrs.MaxThreads.front() = combinedMaxThreadsVal;
attrs.ReductionDataSize = reductionDataSize;
// TODO: Allow modified buffer length similar to
@@ -6302,17 +6316,22 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
- Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ Value numThreads, numTeamsLower, numTeamsUpper;
+ llvm::SmallVector<Value> threadLimitVars;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
+ threadLimitVars, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
+ // Resize to 3 dimensions to match TargetKernelDefaultAttrs
+ attrs.TargetThreadLimit.resize(3);
if (!targetOp.getThreadLimitVars().empty()) {
- Value targetThreadLimit = targetOp.getThreadLimit(0);
- attrs.TargetThreadLimit.front() =
- moduleTranslation.lookupValue(targetThreadLimit);
+ for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
+ if (limitVar) {
+ attrs.TargetThreadLimit[i] = moduleTranslation.lookupValue(limitVar);
+ }
+ }
}
// The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
@@ -6322,13 +6341,21 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
attrs.MinTeams = builder.CreateSExtOrTrunc(
moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
+ // Resize to 3 dimensions to match TargetKernelDefaultAttrs
+ attrs.MaxTeams.resize(3);
if (numTeamsUpper)
attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
- if (teamsThreadLimit)
- attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
- moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
+ attrs.TeamsThreadLimit.resize(3);
+ if (!threadLimitVars.empty()) {
+ for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
+ if (limitVar) {
+ attrs.TeamsThreadLimit[i] = builder.CreateSExtOrTrunc(
+ moduleTranslation.lookupValue(limitVar), builder.getInt32Ty());
+ }
+ }
+ }
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index abc67017b620d..9cc9a5025b613 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
// CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
// CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
// CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
// CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
// CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
// CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]])
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 4690b51122beb..adca15d1c5fcc 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -311,3 +311,39 @@ llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeam
llvm.call @afterTeams() : () -> ()
llvm.return
}
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_2d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i32 [[LIMIT_Y:.+]])
+llvm.func @omp_teams_thread_limit_2d(%limitX: i32, %limitY: i32) {
+ // Multi-dimensional thread_limit: all dimensions are passed
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams thread_limit(%limitX, %limitY : i32, i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_3d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i64 [[LIMIT_Y:.+]], i16 [[LIMIT_Z:.+]])
+llvm.func @omp_teams_thread_limit_3d(%limitX: i32, %limitY: i64, %limitZ: i16) {
+ // Multi-dimensional thread_limit with mixed types: all dimensions are passed
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams thread_limit(%limitX, %limitY, %limitZ : i32, i64, i16) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 36338e5cb1bed..ec1c134c2c07a 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -454,10 +454,10 @@ llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
// -----
-llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
- // expected-error at below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_thread_limit_too_many_dims(%lb : i32, %ub : i32) {
+ // expected-error at below {{not yet implemented: Unhandled clause thread_limit with more than 3 dimensions in omp.teams operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.teams}}
- omp.teams thread_limit(%lb, %ub : i32, i32) {
+ omp.teams thread_limit(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
omp.terminator
}
llvm.return
``````````
</details>
https://github.com/llvm/llvm-project/pull/179608
More information about the Mlir-commits
mailing list