[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams mlir to llvm lowering (PR #179418)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Feb 15 22:10:10 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/179418
From 4b0fabc76650041ceec5f94ee4206031915456e7 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 3 Feb 2026 14:47:58 +0530
Subject: [PATCH] [OpenMP][MLIR] Add num_teams mlir to llvm lowering
---
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 117 +++++++++++++-----
.../LLVMIR/openmp-target-launch-host.mlir | 6 +-
mlir/test/Target/LLVMIR/openmp-todo.mlir | 20 ++-
3 files changed, 104 insertions(+), 39 deletions(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a76419353b1b6..bfe579f5acb5e 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -381,8 +381,28 @@ static LogicalResult checkImplementationStatus(Operation &op) {
result = todo("task_reduction");
};
auto checkNumTeams = [&todo](auto op, LogicalResult &result) {
- if (op.hasNumTeamsMultiDim())
- result = todo("num_teams with multi-dimensional values");
+ if (op.getNumTeamsDimsCount() > 3) {
+ result = todo("num_teams with more than 3 dimensions");
+ return;
+ }
+
+ // Multi-dimensional num_teams is only fully supported within target
+ // regions.
+ if (op.hasNumTeamsMultiDim()) {
+ Operation *parent = op.getOperation()->getParentOp();
+ bool insideTarget = false;
+ while (parent) {
+ if (isa<omp::TargetOp>(parent)) {
+ insideTarget = true;
+ break;
+ }
+ parent = parent->getParentOp();
+ }
+
+ if (!insideTarget)
+ result = todo(
+ "num_teams with multi-dimensional values outside target region");
+ }
};
auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
if (op.hasNumThreadsMultiDim())
@@ -6032,13 +6052,12 @@ createDeviceArgumentAccessor(MapInfoData &mapData, llvm::Argument &arg,
///
/// Loop bounds and steps are only optionally populated, if output vectors are
/// provided.
-static void
-extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
- Value &numTeamsLower, Value &numTeamsUpper,
- Value &threadLimit,
- llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
- llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
- llvm::SmallVectorImpl<Value> *steps = nullptr) {
+static void extractHostEvalClauses(
+ omp::TargetOp targetOp, Value &numThreads, Value &numTeamsLower,
+ llvm::SmallVectorImpl<Value> &numTeamsUpperVars, Value &threadLimit,
+ llvm::SmallVectorImpl<Value> *lowerBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *upperBounds = nullptr,
+ llvm::SmallVectorImpl<Value> *steps = nullptr) {
auto blockArgIface = llvm::cast<omp::BlockArgOpenMPOpInterface>(*targetOp);
for (auto item : llvm::zip_equal(targetOp.getHostEvalVars(),
blockArgIface.getHostEvalBlockArgs())) {
@@ -6050,10 +6069,19 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
if (teamsOp.getNumTeamsLower() == blockArg)
numTeamsLower = hostEvalVar;
else if (llvm::is_contained(teamsOp.getNumTeamsUpperVars(),
- blockArg))
- numTeamsUpper = hostEvalVar;
- else if (!teamsOp.getThreadLimitVars().empty() &&
- teamsOp.getThreadLimit(0) == blockArg)
+ blockArg)) {
+ // Find which dimension this blockArg corresponds to
+ for (auto [i, upperVar] :
+ llvm::enumerate(teamsOp.getNumTeamsUpperVars())) {
+ if (upperVar == blockArg) {
+ if (numTeamsUpperVars.size() <= i)
+ numTeamsUpperVars.resize(i + 1);
+ numTeamsUpperVars[i] = hostEvalVar;
+ break;
+ }
+ }
+ } else if (!teamsOp.getThreadLimitVars().empty() &&
+ teamsOp.getThreadLimit(0) == blockArg)
threadLimit = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6163,19 +6191,22 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
bool isTargetDevice, bool isGPU) {
// TODO: Handle constant 'if' clauses.
- Value numThreads, numTeamsLower, numTeamsUpper, threadLimit;
+ Value numThreads, numTeamsLower, threadLimit;
+ llvm::SmallVector<Value> numTeamsUpperVars;
if (!isTargetDevice) {
- extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
- threadLimit);
+ extractHostEvalClauses(targetOp, numThreads, numTeamsLower,
+ numTeamsUpperVars, threadLimit);
} else {
// In the target device, values for these clauses are not passed as
// host_eval, but instead evaluated prior to entry to the region. This
// ensures values are mapped and available inside of the target region.
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
numTeamsLower = teamsOp.getNumTeamsLower();
- // Handle num_teams upper bounds (only first value for now)
- if (!teamsOp.getNumTeamsUpperVars().empty())
- numTeamsUpper = teamsOp.getNumTeams(0);
+ // Handle all num_teams upper bound dimensions
+ numTeamsUpperVars.reserve(teamsOp.getNumTeamsUpperVars().size());
+ for (auto upperVar : teamsOp.getNumTeamsUpperVars())
+ numTeamsUpperVars.push_back(upperVar);
+ // Handle thread_limit (only first value for now)
if (!teamsOp.getThreadLimitVars().empty())
threadLimit = teamsOp.getThreadLimit(0);
}
@@ -6188,23 +6219,30 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Handle clauses impacting the number of teams.
- int32_t minTeamsVal = 1, maxTeamsVal = -1;
+ int32_t minTeamsVal = 1;
+ llvm::SmallVector<int32_t, 3> maxTeamsVals(3, -1);
if (castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
- // TODO: Use `hostNumTeamsLower` to initialize `minTeamsVal`. For now,
+ // TODO: Use `numTeamsLower` to initialize `minTeamsVal`. For now,
// match clang and set min and max to the same value.
- if (numTeamsUpper) {
- if (auto val = extractConstInteger(numTeamsUpper))
- minTeamsVal = maxTeamsVal = *val;
- } else {
- minTeamsVal = maxTeamsVal = 0;
+ if (!numTeamsUpperVars.empty()) {
+ // Handle multi-dimensional num_teams
+ for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+ if (upperVar) {
+ if (auto val = extractConstInteger(upperVar)) {
+ maxTeamsVals[i] = *val;
+ if (i == 0)
+ minTeamsVal = *val;
+ }
+ }
+ }
}
} else if (castOrGetParentOfType<omp::ParallelOp>(capturedOp,
/*immediateParent=*/true) ||
castOrGetParentOfType<omp::SimdOp>(capturedOp,
/*immediateParent=*/true)) {
- minTeamsVal = maxTeamsVal = 1;
+ minTeamsVal = maxTeamsVals[0] = 1;
} else {
- minTeamsVal = maxTeamsVal = -1;
+ minTeamsVal = maxTeamsVals[0] = -1;
}
// Handle clauses impacting the number of threads.
@@ -6273,7 +6311,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
attrs.ExecFlags = llvm::omp::OMP_TGT_EXEC_MODE_SPMD_NO_LOOP;
attrs.MinTeams = minTeamsVal;
- attrs.MaxTeams.front() = maxTeamsVal;
+ // Always resize to 3 dimensions to match TargetKernelRuntimeAttrs
+ attrs.MaxTeams.resize(3, -1);
+ for (size_t i = 0; i < maxTeamsVals.size() && i < attrs.MaxTeams.size(); ++i)
+ attrs.MaxTeams[i] = maxTeamsVals[i];
attrs.MinThreads = 1;
attrs.MaxThreads.front() = combinedMaxThreadsVal;
attrs.ReductionDataSize = reductionDataSize;
@@ -6297,10 +6338,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
omp::LoopNestOp loopOp = castOrGetParentOfType<omp::LoopNestOp>(capturedOp);
unsigned numLoops = loopOp ? loopOp.getNumLoops() : 0;
- Value numThreads, numTeamsLower, numTeamsUpper, teamsThreadLimit;
+ Value numThreads, numTeamsLower, teamsThreadLimit;
+ llvm::SmallVector<Value> numTeamsUpperVars;
llvm::SmallVector<Value> lowerBounds(numLoops), upperBounds(numLoops),
steps(numLoops);
- extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpper,
+ extractHostEvalClauses(targetOp, numThreads, numTeamsLower, numTeamsUpperVars,
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
@@ -6317,9 +6359,16 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
attrs.MinTeams = builder.CreateSExtOrTrunc(
moduleTranslation.lookupValue(numTeamsLower), builder.getInt32Ty());
- if (numTeamsUpper)
- attrs.MaxTeams.front() = builder.CreateSExtOrTrunc(
- moduleTranslation.lookupValue(numTeamsUpper), builder.getInt32Ty());
+ // Handle multi-dimensional num_teams upper bounds
+ attrs.MaxTeams.resize(3);
+ if (!numTeamsUpperVars.empty()) {
+ for (auto [i, upperVar] : llvm::enumerate(numTeamsUpperVars)) {
+ if (upperVar) {
+ attrs.MaxTeams[i] = builder.CreateSExtOrTrunc(
+ moduleTranslation.lookupValue(upperVar), builder.getInt32Ty());
+ }
+ }
+ }
if (teamsThreadLimit)
attrs.TeamsThreadLimit.front() = builder.CreateSExtOrTrunc(
diff --git a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
index abc67017b620d..9cc9a5025b613 100644
--- a/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-target-launch-host.mlir
@@ -2,13 +2,13 @@
// CHECK: define void @main(i32 %[[NUM_TEAMS_ARG:.*]])
// CHECK: %[[KERNEL_ARGS:.*]] = alloca %struct.__tgt_kernel_arguments
-// CHECK: %[[NUM_TEAMS:.*]] = insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
+// CHECK: insertvalue [3 x i32] zeroinitializer, i32 %[[NUM_TEAMS_ARG]], 0
// CHECK: %[[NUM_TEAMS_KARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 10
-// CHECK: store [3 x i32] %[[NUM_TEAMS]], ptr %[[NUM_TEAMS_KARG]], align 4
+// CHECK-NEXT: store [3 x i32] %{{.*}}, ptr %[[NUM_TEAMS_KARG]], align 4
// CHECK: %[[NUM_THREADS_ARG:.*]] = getelementptr inbounds nuw %struct.__tgt_kernel_arguments, ptr %[[KERNEL_ARGS]], i32 0, i32 11
-// CHECK: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
+// CHECK-NEXT: store [3 x i32] [i32 10, i32 0, i32 0], ptr %[[NUM_THREADS_ARG]], align 4
// CHECK: %{{.*}} = call i32 @__tgt_target_kernel(ptr {{.*}}, i64 -1, i32 %[[NUM_TEAMS_ARG]], i32 [[NUM_THREADS:10]], ptr @.[[OUTLINED_FN:.*]].region_id, ptr %[[KERNEL_ARGS]])
// CHECK: call void @[[OUTLINED_FN]](i32 %[[NUM_TEAMS_ARG]])
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 70e4edf0705f2..df700cec114f8 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -432,8 +432,8 @@ llvm.func @teams_private(%x : !llvm.ptr) {
// -----
-llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
- // expected-error@below {{not yet implemented: Unhandled clause num_teams with multi-dimensional values in omp.teams operation}}
+llvm.func @teams_num_teams_multi_dim_standalone(%lb : i32, %ub : i32) {
+ // expected-error@below {{not yet implemented: Unhandled clause num_teams with multi-dimensional values outside target region in omp.teams operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.teams}}
omp.teams num_teams(to %ub, %ub, %ub : i32, i32, i32) {
omp.terminator
@@ -443,6 +443,22 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
+llvm.func @teams_num_teams_too_many_dims() {
+ // expected-error@below {{LLVM Translation failed for operation: omp.target}}
+ omp.target {
+ %c100 = llvm.mlir.constant(100 : i32) : i32
+ // expected-error@below {{not yet implemented: Unhandled clause num_teams with more than 3 dimensions in omp.teams operation}}
+ // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
+ omp.teams num_teams(to %c100, %c100, %c100, %c100 : i32, i32, i32, i32) {
+ omp.terminator
+ }
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
// expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.parallel}}
More information about the Mlir-commits
mailing list