[Mlir-commits] [mlir] [OpenMP][MLIR] Add thread_limit mlir->llvm lowering (PR #179608)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 15 22:36:49 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/179608
>From 5526cfba4863e18bb37088166327180076a24299 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 4 Feb 2026 10:21:06 +0530
Subject: [PATCH] [OpenMP][MLIR] Add thread_limit mlir->llvm lowering
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 63 +++++++++++++------
.../LLVMIR/openmp-target-launch-host.mlir | 6 +-
mlir/test/Target/LLVMIR/openmp-teams.mlir | 36 +++++++++++
mlir/test/Target/LLVMIR/openmp-todo.mlir | 6 +-
4 files changed, 85 insertions(+), 26 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a76419353b1b6..9d1dbcfb4ad77 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -390,8 +390,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
};
auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
- if (op.hasThreadLimitMultiDim())
- result = todo("thread_limit with multi-dimensional values");
+ if (op.getThreadLimitDimsCount() > 3)
+ result = todo("thread_limit with more than 3 dimensions");
};
LogicalResult result = success();
@@ -6035,7 +6035,7 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
static void
extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
Value &numTeamsLower, Value &numTeamsUpper,
- Value &threadLimit,
+ llvm::SmallVectorImpl<Value> &threadLimitVars,
llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
llvm::SmallVectorImpl<Value> *steps = nullptr) {
@@ -6052,10 +6052,18 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
blockArg))
numTeamsUpper = hostEvalVar;
- else if (!teamsOp.getThreadLimitVars().empty() &&
- teamsOp.getThreadLimit(0) == blockArg)
- threadLimit = hostEvalVar;
- else
+ else if (llvm::is_contained(teamsOp.getThreadLimitVars(),
+ blockArg)) {
+ for (auto [i, limitVar] :
+ llvm::enumerate(teamsOp.getThreadLimitVars())) {
+ if (limitVar == blockArg) {
+ if (threadLimitVars.size() <= i)
+ threadLimitVars.resize(i + 1);
+ threadLimitVars[i] = hostEvalVar;
+ break;
+ }
+ }
+ } else
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
@@ -6163,10 +6171,11 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
bool isTargetDevice, bool isGPU) {
// TODO: Handle constant 'if' clauses.
- Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ Value numThreads, numTeamsLower, numTeamsUpper;
+ llvm::SmallVector<Value> threadLimitVars;
if (!isTargetDevice) {
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- threadLimit);
+ threadLimitVars);
} else {
// In the target device, values for these clauses are not passed as
// host_eval, but instead evaluated prior to entry to the region. This
@@ -6176,8 +6185,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Handle num_teams upper bounds (only first value for now)
if (!teamsOp.getNumTeamsUpperVars().empty())
numTeamsUpper = teamsOp.getNumTeams(0);
- if (!teamsOp.getThreadLimitVars().empty())
- threadLimit = teamsOp.getThreadLimit(0);
+ threadLimitVars.reserve(teamsOp.getThreadLimitVars().size());
+ for (auto limitVar : teamsOp.getThreadLimitVars())
+ threadLimitVars.push_back(limitVar);
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
@@ -6226,7 +6236,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
if (!targetOp.getThreadLimitVars().empty())
setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
- setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
+ if (!threadLimitVars.empty())
+ setMaxValueFromClause(threadLimitVars[0], teamsThreadLimitVal);
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
int32_t maxThreadsVal = -1;
@@ -6275,6 +6286,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
attrs.MinTeams = minTeamsVal;
attrs.MaxTeams.front() = maxTeamsVal;
attrs.MinThreads = 1;
+ attrs.MaxThreads.resize(3, -1);
attrs.MaxThreads.front() = combinedMaxThreadsVal;
attrs.ReductionDataSize = reductionDataSize;
// TODO: Allow modified buffer length similar to
@@ -6297,17 +6309,22 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
- Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ Value numThreads, numTeamsLower, numTeamsUpper;
+ llvm::SmallVector<Value> threadLimitVars;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
+ threadLimitVars, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
+ // Resize to 3 dimensions to match TargetKernelDefaultAttrs
+ attrs.TargetThreadLimit.resize(3);
if (!targetOp.getThreadLimitVars().empty()) {
- Value targetThreadLimit = targetOp.getThreadLimit(0);
- attrs.TargetThreadLimit.front() =
- moduleTranslation.lookupValue(targetThreadLimit);
+ for (auto [i, limitVar] : llvm::enumerate(targetOp.getThreadLimitVars())) {
+ if (limitVar) {
+ attrs.TargetThreadLimit[i] = moduleTranslation.lookupValue(limitVar);
+ }
+ }
}
// The __kmpc_push_num_teams_51 function expects int32 as the arguments. So,
@@ -6321,9 +6338,15 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
- if (teamsThreadLimit)
- attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
- moduleTranslation.lookupValue(teamsThreadLimit), builder.getInt32Ty());
+ attrs.TeamsThreadLimit.resize(3);
+ if (!threadLimitVars.empty()) {
+ for (auto [i, limitVar] : llvm::enumerate(threadLimitVars)) {
+ if (limitVar) {
+ attrs.TeamsThreadLimit[i] = builder.CreateSExtOrTrunc(
+ moduleTranslation.lookupValue(limitVar), builder.getInt32Ty());
+ }
+ }
+ }
if (numThreads)
attrs.MaxThreads = moduleTranslation.lookupValue(numThreads);
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index abc67017b620d..9cc9a5025b613 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
// CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
// CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
// CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
// CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
// CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
// CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]])
diff --git a/mlir/test/Target/LLVMIR/openmp-teams.mlir b/mlir/test/Target/LLVMIR/openmp-teams.mlir
index 4690b51122beb..adca15d1c5fcc 100644
--- a/mlir/test/Target/LLVMIR/openmp-teams.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-teams.mlir
@@ -311,3 +311,39 @@ llvm.func @teams_if_with_num_teams(%condition: i1, %numTeamsLower: i32, %numTeam
llvm.call @afterTeams() : () -> ()
llvm.return
}
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_2d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i32 [[LIMIT_Y:.+]])
+llvm.func @omp_teams_thread_limit_2d(%limitX: i32, %limitY: i32) {
+ // 2-D thread_limit is accepted; only the first value (LIMIT_X) is verified as the push_num_teams_51 limit, LIMIT_Y is unchecked
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams thread_limit(%limitX, %limitY : i32, i32) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
+llvm.func @duringTeams()
+
+// CHECK-LABEL: @omp_teams_thread_limit_3d
+// CHECK-SAME: (i32 [[LIMIT_X:.+]], i64 [[LIMIT_Y:.+]], i16 [[LIMIT_Z:.+]])
+llvm.func @omp_teams_thread_limit_3d(%limitX: i32, %limitY: i64, %limitZ: i16) {
+ // 3-D mixed-type (i32, i64, i16) thread_limit is accepted; only the first value (LIMIT_X) is verified below, LIMIT_Y/LIMIT_Z are unchecked
+ // CHECK: [[THREAD_NUM:%.+]] = call i32 @__kmpc_global_thread_num
+ // CHECK-NEXT: call void @__kmpc_push_num_teams_51({{.+}}, i32 [[THREAD_NUM]], i32 0, i32 0, i32 [[LIMIT_X]])
+ // CHECK: call void (ptr, i32, ptr, ...) @__kmpc_fork_teams(ptr @{{[0-9]+}}, i32 0, ptr [[OUTLINED_FN:.+]])
+ omp.teams thread_limit(%limitX, %limitY, %limitZ : i32, i64, i16) {
+ llvm.call @duringTeams() : () -> ()
+ omp.terminator
+ }
+ llvm.return
+}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 70e4edf0705f2..384ad4a65cf39 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -454,10 +454,10 @@ llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
// -----
-llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
- // expected-error at below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_thread_limit_too_many_dims(%lb : i32, %ub : i32) {
+ // expected-error at below {{not yet implemented: Unhandled clause thread_limit with more than 3 dimensions in omp.teams operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.teams}}
- omp.teams thread_limit(%lb, %ub : i32, i32) {
+ omp.teams thread_limit(%lb, %ub, %lb, %ub : i32, i32, i32, i32) {
omp.terminator
}
llvm.return
More information about the Mlir-commits mailing list