[Mlir-commits] [mlir] 08654ad - [OpenMP][MLIR] Add num_threads clause with dims modifier support (#171767)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 27 02:01:03 PST 2026
Author: Chaitanya
Date: 2026-01-27T15:30:55+05:30
New Revision: 08654adc62c1d7c9a0f8fe40138fe4317a941f3b
URL: https://github.com/llvm/llvm-project/commit/08654adc62c1d7c9a0f8fe40138fe4317a941f3b
DIFF: https://github.com/llvm/llvm-project/commit/08654adc62c1d7c9a0f8fe40138fe4317a941f3b.diff
LOG: [OpenMP][MLIR] Add num_threads clause with dims modifier support (#171767)
This PR adds support for the OpenMP 6.1 `num_threads` clause with the dims modifier.
LLVM IR translation for `num_threads` with the dims modifier is marked as not yet implemented (NYI).
Added:
flang/test/Lower/OpenMP/num-threads-dims.f90
Modified:
flang/lib/Lower/OpenMP/ClauseProcessor.cpp
flang/lib/Lower/OpenMP/Clauses.cpp
flang/lib/Lower/OpenMP/OpenMP.cpp
llvm/include/llvm/Frontend/OpenMP/ClauseT.h
mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
mlir/test/Dialect/OpenMP/ops.mlir
mlir/test/Target/LLVMIR/openmp-todo.mlir
Removed:
################################################################################
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a067db300a31d..b29cfac857841 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -525,9 +525,11 @@ bool ClauseProcessor::processNumThreads(
lower::StatementContext &stmtCtx,
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
- // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result.numThreads =
- fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ // OMPIRBuilder expects `NUM_THREADS` clause as a list of Values.
+ for (const ExprTy &expr : clause->v) {
+ result.numThreadsVars.push_back(
+ fir::getBase(converter.genExprValue(expr, stmtCtx)));
+ }
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index e588d33bda4e5..9a20fb5d006f8 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1302,9 +1302,16 @@ NumTeams make(const parser::OmpClause::NumTeams &inp,
NumThreads make(const parser::OmpClause::NumThreads &inp,
semantics::SemanticsContext &semaCtx) {
// inp.v -> parser::OmpNumThreadsClause
- auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
- assert(!t1.empty());
- return NumThreads{/*Nthreads=*/makeExpr(t1.front(), semaCtx)};
+ // With dims modifier (OpenMP 6.1): multiple values
+ // Without dims modifier: single value
+ auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+ assert(!values.empty());
+
+ List<NumThreads::Nthreads> v;
+ for (const auto &val : values) {
+ v.push_back(makeExpr(val, semaCtx));
+ }
+ return NumThreads{/*Nthreads=*/v};
}
// OmpxAttribute: empty
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 8757783945bd9..1902bd7a21f4a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -100,8 +100,8 @@ class HostEvalInfo {
for (auto numTeamsUpper : ops.numTeamsUpperVars)
vars.push_back(numTeamsUpper);
- if (ops.numThreads)
- vars.push_back(ops.numThreads);
+ for (auto numThreads : ops.numThreadsVars)
+ vars.push_back(numThreads);
if (ops.threadLimit)
vars.push_back(ops.threadLimit);
@@ -116,7 +116,7 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- ops.numTeamsUpperVars.size() + (ops.numThreads ? 1 : 0) +
+ ops.numTeamsUpperVars.size() + ops.numThreadsVars.size() +
(ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
@@ -135,8 +135,8 @@ class HostEvalInfo {
for (size_t i = 0; i < ops.numTeamsUpperVars.size(); ++i)
ops.numTeamsUpperVars[i] = args[argIndex++];
- if (ops.numThreads)
- ops.numThreads = args[argIndex++];
+ for (size_t i = 0; i < ops.numThreadsVars.size(); ++i)
+ ops.numThreadsVars[i] = args[argIndex++];
if (ops.threadLimit)
ops.threadLimit = args[argIndex++];
@@ -170,13 +170,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::ParallelOperands &clauseOps) {
- if (!ops.numThreads || parallelApplied) {
+ if (ops.numThreadsVars.empty() || parallelApplied) {
parallelApplied = true;
return false;
}
parallelApplied = true;
- clauseOps.numThreads = ops.numThreads;
+ clauseOps.numThreadsVars = ops.numThreadsVars;
return true;
}
diff --git a/flang/test/Lower/OpenMP/num-threads-dims.f90 b/flang/test/Lower/OpenMP/num-threads-dims.f90
new file mode 100644
index 0000000000000..f3a8d706b7283
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-threads-dims.f90
@@ -0,0 +1,61 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_threads` clause with dims modifier (OpenMP 6.1)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims4
+subroutine parallel_numthreads_dims4()
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32, i32)
+ !$omp parallel num_threads(dims(4): 4, 5, 6, 7)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end parallel
+end subroutine parallel_numthreads_dims4
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims2
+subroutine parallel_numthreads_dims2()
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}} : i32, i32)
+ !$omp parallel num_threads(dims(2): 8, 4)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end parallel
+end subroutine parallel_numthreads_dims2
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims_var
+subroutine parallel_numthreads_dims_var(a, b, c)
+ integer, intent(in) :: a, b, c
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
+ !$omp parallel num_threads(dims(3): a, b, c)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end parallel
+end subroutine parallel_numthreads_dims_var
+
+!===============================================================================
+! `num_threads` clause without dims modifier (legacy)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPparallel_numthreads_legacy
+subroutine parallel_numthreads_legacy(n)
+ integer, intent(in) :: n
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%{{.*}} : i32)
+ !$omp parallel num_threads(n)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end parallel
+end subroutine parallel_numthreads_legacy
+
+! CHECK-LABEL: func @_QPparallel_numthreads_const
+subroutine parallel_numthreads_const()
+ ! CHECK: omp.parallel
+ ! CHECK-SAME: num_threads(%{{.*}} : i32)
+ !$omp parallel num_threads(16)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end parallel
+end subroutine parallel_numthreads_const
diff --git a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
index 32cb1be416eb0..86353f0fe0bda 100644
--- a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
+++ b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
@@ -1005,11 +1005,13 @@ struct NumTeamsT {
};
// V5.2: [10.1.2] `num_threads` clause
+// V6.1: Extended with dims modifier support
template <typename T, typename I, typename E> //
struct NumThreadsT {
using Nthreads = E;
+ using List = ListT<Nthreads>;
using WrapperTrait = std::true_type;
- Nthreads v;
+ List v;
};
template <typename T, typename I, typename E> //
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 2f7169c6f3e2c..0bfe36648c54d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1107,16 +1107,47 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- Optional<IntLikeType>:$num_threads
+ Variadic<IntLikeType>:$num_threads_vars
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` $num_threads_vars `:` type($num_threads_vars) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ The `num_threads` clause specifies the number of threads.
+
+ Multi-dimensional format (dims modifier):
+ - Multiple values can be specified for multi-dimensional thread counts.
+ - The number of dimensions is derived from the number of values.
+ - Values can have different integer types.
+ - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)`
+ - Example: `num_threads(%n, %m : i32, i64)`
+
+ Single value format:
+ - A single value specifies the number of threads.
+ - Format: `num_threads(%value : type)`
+ - Example: `num_threads(%n : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if using multi-dimensional values (more than one value)
+ bool hasNumThreadsMultiDim() {
+ return getNumThreadsVars().size() > 1;
+ }
+
+ /// Returns the number of dimensions specified for num_threads
+ unsigned getNumThreadsDimsCount() {
+ return getNumThreadsVars().size();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumThreadsVars().size()
+ ::mlir::Value getNumThreads(unsigned dim = 0) {
+ assert(dim < getNumThreadsDimsCount() &&
+ "Num threads index out of bounds");
+ return getNumThreadsVars()[dim];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 5fcaea7f39c3c..48845734e9547 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -486,10 +486,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
}
rewriter.eraseOp(reduce);
- Value numThreadsVar;
+ SmallVector<Value> numThreadsVars;
if (numThreads > 0) {
- numThreadsVar = LLVM::ConstantOp::create(
+ Value numThreadsVar = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+ numThreadsVars.push_back(numThreadsVar);
}
// Create the parallel wrapper.
auto ompParallel = omp::ParallelOp::create(
@@ -497,7 +498,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads = */ numThreadsVar,
+ /* num_threads_vars = */ numThreadsVars,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 8d06a9c66808a..349fc300e644d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,7 +2252,7 @@ LogicalResult TargetOp::verifyRegions() {
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
- hostEvalArg == parallelOp.getNumThreads())
+ llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
continue;
return emitOpError()
@@ -2504,7 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
- /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+ /*num_threads_vars=*/ValueRange(),
+ /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
/*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
@@ -2516,7 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
+ clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.procBindKind,
clauses.reductionMod, clauses.reductionVars,
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 5af1dea34f482..924108d5702f7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
+ auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
+ if (op.hasNumThreadsMultiDim())
+ result = todo("num_threads with multi-dimensional values");
+ };
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
@@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::ParallelOp op) {
checkAllocate(op, result);
checkReduction(op, result);
+ checkNumThreads(op, result);
})
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3268,8 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
- if (auto numThreadsVar = opInst.getNumThreads())
- numThreads = moduleTranslation.lookupValue(numThreadsVar);
+ if (!opInst.getNumThreadsVars().empty())
+ numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
pbKind = getProcBindKind(*bind);
@@ -6051,7 +6056,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
- if (parallelOp.getNumThreads() == blockArg)
+ if (!parallelOp.getNumThreadsVars().empty() &&
+ parallelOp.getNumThreads(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6170,8 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
- numThreads = parallelOp.getNumThreads();
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ if (!parallelOp.getNumThreadsVars().empty())
+ numThreads = parallelOp.getNumThreads(0);
+ }
}
// Handle clauses impacting the number of teams.
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3afab239e19c5..563729039c762 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -160,6 +160,18 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64)
+ omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) {
+ omp.terminator
+ }
+
+ %n_i16 = arith.constant 8 : i16
+ // Test num_threads with mixed types.
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+ omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 09919ae4f5267..ca5ec559926a9 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -443,6 +443,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
+llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
+ // expected-error @below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+ // expected-error @below {{LLVM Translation failed for operation: omp.parallel}}
+ omp.parallel num_threads(%lb, %ub : i32, i32) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// expected-error @below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
// expected-error @below {{LLVM Translation failed for operation: omp.wsloop}}
More information about the Mlir-commits
mailing list