[llvm-branch-commits] [flang] [mlir] [OpenMP][MLIR] Add num_threads clause with dims modifier support (PR #171767)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Sat Jan 17 01:01:08 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 77e758855c0e5cf3072704bbe461682ab192a84a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 13:37:14 +0530
Subject: [PATCH 1/8] [OpenMP][MLIR] Add num_teams clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 72 +++++++++++++++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++
mlir/test/Dialect/OpenMP/invalid.mlir | 19 +----
4 files changed, 80 insertions(+), 19 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b612d4e136baf..ed24530464ea4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1567,4 +1567,76 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V6.2: Multidimensional `num_teams` clause with dims modifier
+//===----------------------------------------------------------------------===//
+
+class OpenMP_NumTeamsMultiDimClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+ Variadic<AnyInteger>:$num_teams_values
+ );
+
+ let optAssemblyFormat = [{
+ `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+ $num_teams_values,
+ type($num_teams_values)) `)`
+ }];
+
+ let description = [{
+ The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
+ the number of teams to be created in a multidimensional team space.
+
+ The dims modifier for the num_teams_multi_dim clause specifies the number of
+ dimensions for the league space (team space) that the clause arranges.
+ The dimensions argument in the dims modifier specifies the number of
+ dimensions and determines the length of the list argument. The list items
+ are specified in ascending order according to the ordinal number of the
+ dimensions (dimension 0, 1, 2, ..., N-1).
+
+ - If `dims` is not specified: The space is unidimensional (1D) with a single value
+ - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
+ - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
+
+ **Examples:**
+ - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
+ 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
+ - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumTeamsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ /// Returns 1 if dims modifier is not present (unidimensional by default)
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumTeamsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumTeamsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
+ }];
+}
+
+def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
+
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d4e8cecda2601..76eeb0bd70ec3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+ OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+ OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 25bf4e70d9a83..7a9a45b160ba3 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2625,8 +2625,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+<<<<<<< HEAD
clauses.ifExpr, clauses.numTeamsVals, clauses.numTeamsLower,
clauses.numTeamsUpper,
+=======
+ clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+ clauses.numTeamsDims, clauses.numTeamsValues,
+>>>>>>> [OpenMP][MLIR] Add num_teams clause with dims modifier support
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d451b14e8bfc9..cd06011c2cbc4 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1451,24 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
- omp.terminator
- }
- return
-}
-
-// -----
-
-func.func @omp_teams_num_teams_multidim_with_bounds() {
- omp.target {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- %lb = arith.constant 3 : i32
- %ub = arith.constant 4 : i32
- // expected-error @below {{num_teams multi-dimensional values cannot be used together with legacy lower/upper bounds}}
- "omp.teams" (%v0, %v1, %lb, %ub) ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
>From 2b170fd042b2bd0b2a147fe413780c45a0e4fbdb Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 11:56:58 +0530
Subject: [PATCH 2/8] [OpenMP][MLIR] Add num_threads clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++--
mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 15 ++--
5 files changed, 163 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index ed24530464ea4..8826c15a15191 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
+ Variadic<AnyInteger>:$num_threads_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - Specifies a single value for the desired number of threads
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumThreadsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumThreadsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 5fcaea7f39c3c..c749106b925f7 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -497,6 +497,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_dims = */ nullptr,
+ /* num_threads_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 7a9a45b160ba3..d75b9e17f1e98 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2504,6 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2515,13 +2517,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2568,13 +2571,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ auto numThreadsDims = getNumThreadsDims();
+ auto numThreadsValues = getNumThreadsValues();
+ auto numThreads = getNumThreads();
+
+ // num_threads with dims modifier
+ if (numThreadsDims.has_value() && numThreadsValues.empty()) {
+ return emitError(
+ "num_threads dims modifier requires values to be specified");
+ }
+
+ if (numThreadsDims.has_value() &&
+ numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
+ return emitError("num_threads dims(")
+ << *numThreadsDims << ") specified but " << numThreadsValues.size()
+ << " values provided";
+ }
+
+ // num_threads dims and number of threads cannot be used together
+ if (numThreadsDims.has_value() && numThreads) {
+ return emitError(
+ "num_threads dims and number of threads cannot be used together");
+ }
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4623,6 +4653,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index cd06011c2cbc4..e55fe3d0a1aec 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2691,7 +2722,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 49a88e0443e60..f9cfd400387a5 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
>From eb17b261fd452be4edcd90ed460675d317ef9f79 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 12:11:49 +0530
Subject: [PATCH 3/8] Mark mlir->llvmir translation for num_threads with dims
as NYI
---
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0b7bf64cefe4c..9c176b56a4d5d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3268,6 +3268,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.getNumThreadsDims().has_value() &&
+ opInst.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -6050,6 +6054,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -6167,8 +6175,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
>From 207eca2b904e1844d82bcffe1baccf690a7f4a1f Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 17:37:52 +0530
Subject: [PATCH 4/8] Rename num_threads dims operands and share dims verification
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++--------
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++--
mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++---
5 files changed, 45 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8826c15a15191..8d8db94630f84 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
- Variadic<AnyInteger>:$num_threads_values,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
$num_threads, type($num_threads)
) `)`
}];
@@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip<
space formed by the construct on which it appears.
With dims modifier:
- - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
- Specifies upper bounds for each dimension (all must have same type)
- Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip<
let extraClassDeclaration = [{
/// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumThreadsDims().has_value();
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
}
/// Returns the number of dimensions specified by dims modifier
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
return 1;
- return static_cast<unsigned>(*getNumThreadsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumThreadsValues();
+ return static_cast<unsigned>(*getNumThreadsNumDims());
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
}
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index c749106b925f7..f9c8cab9b3d7b 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -497,8 +497,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads_dims = */ nullptr,
- /* num_threads_values = */ llvm::SmallVector<Value>{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d75b9e17f1e98..c2aca0887e38d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2519,7 +2519,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
ParallelOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
@@ -2570,30 +2570,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
-LogicalResult ParallelOp::verify() {
- // verify num_threads clause restrictions
- auto numThreadsDims = getNumThreadsDims();
- auto numThreadsValues = getNumThreadsValues();
- auto numThreads = getNumThreads();
-
- // num_threads with dims modifier
- if (numThreadsDims.has_value() && numThreadsValues.empty()) {
- return emitError(
- "num_threads dims modifier requires values to be specified");
- }
-
- if (numThreadsDims.has_value() &&
- numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
- return emitError("num_threads dims(")
- << *numThreadsDims << ") specified but " << numThreadsValues.size()
- << " values provided";
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
}
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
- // num_threads dims and number of threads cannot be used together
- if (numThreadsDims.has_value() && numThreads) {
- return emitError(
- "num_threads dims and number of threads cannot be used together");
- }
+LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
// verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9c176b56a4d5d..2d71910e27a52 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3269,8 +3269,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
- assert(!opInst.getNumThreadsDims().has_value() &&
- opInst.getNumThreadsValues().empty() &&
+ assert(!opInst.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
@@ -6055,8 +6054,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
})
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
@@ -6177,8 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e55fe3d0a1aec..17985651a1286 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) {
// -----
func.func @num_threads_dims_no_values() {
- // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+ // expected-error@+1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
// -----
func.func @num_threads_dims_mismatch(%n : i64) {
- // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
omp.parallel num_threads(dims(2): %n : i64) {
omp.terminator
}
@@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
"omp.parallel"(%n, %n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
return
}
>From cad7b45c8ba7d37198eb2a79d98288d5cb0ed45b Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 12:27:38 +0530
Subject: [PATCH 5/8] Use num_threads_dims_values only
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 15 ++---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 15 +++--
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 5 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 62 ++++++++-----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++---
mlir/test/Dialect/OpenMP/invalid.mlir | 12 ++--
mlir/test/Dialect/OpenMP/ops.mlir | 10 +--
8 files changed, 66 insertions(+), 73 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 2f531efaf09aa..8a96872294124 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,8 +516,8 @@ bool ClauseProcessor::processNumThreads(
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result.numThreads =
- fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsDimsValues.push_back(
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0764693f748a5..7b12750eebb4f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,8 +99,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
vars.push_back(ops.numTeamsUpper);
- if (ops.numThreads)
- vars.push_back(ops.numThreads);
+ for (auto numThreads : ops.numThreadsDimsValues)
+ vars.push_back(numThreads);
if (ops.threadLimit)
vars.push_back(ops.threadLimit);
@@ -115,7 +115,8 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+ (ops.numTeamsUpper ? 1 : 0) +
+ ops.numThreadsDimsValues.size() +
(ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
@@ -134,8 +135,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
ops.numTeamsUpper = args[argIndex++];
- if (ops.numThreads)
- ops.numThreads = args[argIndex++];
+ for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
+ ops.numThreadsDimsValues[i] = args[argIndex++];
if (ops.threadLimit)
ops.threadLimit = args[argIndex++];
@@ -169,13 +170,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::ParallelOperands &clauseOps) {
- if (!ops.numThreads || parallelApplied) {
+ if (ops.numThreadsDimsValues.empty() || parallelApplied) {
parallelApplied = true;
return false;
}
parallelApplied = true;
- clauseOps.numThreads = ops.numThreads;
+ clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
return true;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8d8db94630f84..10aaab4b6f21c 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip<
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
- Variadic<AnyInteger>:$num_threads_dims_values,
- Optional<IntLikeType>:$num_threads
+ Variadic<IntLikeType>:$num_threads_dims_values
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
- $num_threads, type($num_threads)
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
) `)`
}];
@@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip<
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
Without dims modifier:
- - Uses `num_threads`
- - If lower bound not specified, it defaults to upper bound value
- - Format: `num_threads(bounds : type)`
- - Example: `num_threads(%ub : i32)`
+ - The number of threads is specified by single value in `num_threads_dims_values`
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%n : i32)`
}];
let extraClassDeclaration = [{
@@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip<
::mlir::Value getNumThreadsDimsValue(unsigned index) {
assert(index < getNumThreadsDimsCount() &&
"Num threads dims index out of bounds");
+ if(getNumThreadsDimsValues().empty())
+ return nullptr;
return getNumThreadsDimsValues()[index];
}
}];
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index f9c8cab9b3d7b..3a1f311dd63f0 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -487,9 +487,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.eraseOp(reduce);
Value numThreadsVar;
+ SmallVector<Value> numThreadsValues;
if (numThreads > 0) {
numThreadsVar = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+ numThreadsValues.push_back(numThreadsVar);
}
// Create the parallel wrapper.
auto ompParallel = omp::ParallelOp::create(
@@ -498,8 +500,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
/* num_threads_num_dims = */ nullptr,
- /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
- /* num_threads = */ numThreadsVar,
+ /* num_threads_dims_values = */ numThreadsValues,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c2aca0887e38d..9366f04e51629 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,7 +2252,8 @@ LogicalResult TargetOp::verifyRegions() {
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
- hostEvalArg == parallelOp.getNumThreads())
+ llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
+ hostEvalArg))
continue;
return emitOpError()
@@ -2506,7 +2507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads_dims=*/nullptr,
/*num_threads_values=*/ValueRange(),
- /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+ /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
/*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
@@ -2517,14 +2518,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
- clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
- clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims,
+ clauses.numThreadsDimsValues, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms),
+ clauses.privateNeedsBarrier, clauses.procBindKind,
+ clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2574,13 +2575,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
LogicalResult
verifyNumThreadsClause(Operation *op,
std::optional<IntegerAttr> numThreadsNumDims,
- OperandRange numThreadsDimsValues, Value numThreads) {
- bool hasDimsModifier =
- numThreadsNumDims.has_value() && numThreadsNumDims.value();
- if (hasDimsModifier && numThreads) {
- return op->emitError("num_threads with dims modifier cannot be used "
- "together with number of threads");
- }
+ OperandRange numThreadsDimsValues) {
if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
return failure();
return success();
@@ -2588,9 +2583,9 @@ verifyNumThreadsClause(Operation *op,
LogicalResult ParallelOp::verify() {
// verify num_threads clause restrictions
- if (failed(verifyNumThreadsClause(
- getOperation(), this->getNumThreadsNumDimsAttr(),
- this->getNumThreadsDimsValues(), this->getNumThreads())))
+ if (failed(verifyNumThreadsClause(getOperation(),
+ this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues())))
return failure();
// verify allocate clause restrictions
@@ -4657,33 +4652,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
static ParseResult
parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types,
- std::optional<OpAsmParser::UnresolvedOperand> &bounds,
- Type &boundsType) {
+ SmallVectorImpl<Type> &types) {
if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
return success();
}
- OpAsmParser::UnresolvedOperand boundsOperand;
- if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
- parser.parseType(boundsType)) {
+ // Without dims modifier: value : type
+ OpAsmParser::UnresolvedOperand singleValue;
+ Type singleType;
+ if (parser.parseOperand(singleValue) || parser.parseColon() ||
+ parser.parseType(singleType)) {
return failure();
}
- bounds = boundsOperand;
+ values.push_back(singleValue);
+ types.push_back(singleType);
return success();
}
static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
IntegerAttr dimsAttr, OperandRange values,
- TypeRange types, Value bounds,
- Type boundsType) {
- if (!values.empty()) {
- printDimsModifierWithValues(p, dimsAttr, values, types);
- }
- if (bounds) {
- p.printOperand(bounds);
- p << " : " << boundsType;
- }
+ TypeRange types) {
+ // Multidimensional: dims(N): values : type
+ printDimsModifierWithValues(p, dimsAttr, values, types);
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2d71910e27a52..d4aaa832636d1 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3270,8 +3270,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
assert(!opInst.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- if (auto numThreadsVar = opInst.getNumThreads())
+ "Lowering of num_threads with dims modifier is not yet implemented.");
+ if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
@@ -6055,8 +6055,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- if (parallelOp.getNumThreads() == blockArg)
+ "Lowering of num_threads with dims modifier is not yet "
+ "implemented.");
+ if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6175,9 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- numThreads = parallelOp.getNumThreads();
+ assert(
+ !parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is not yet implemented.");
+ numThreads = parallelOp.getNumThreadsDimsValue(0);
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 17985651a1286..b2e20b4c5ee5a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() {
// expected-error@+1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
@@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
-func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
- "omp.parallel"(%n, %n, %m) ({
+func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
+ // expected-error@+1 {{dims values can only be specified with dims modifier}}
+ "omp.parallel"(%n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
return
}
@@ -2722,7 +2722,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index f9cfd400387a5..e2a3f8fbe2d5f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
>From e2b12cf37fea11b6c129b40db66f8524a7cb467c Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 Jan 2026 12:07:56 +0530
Subject: [PATCH 6/8] fix adding numThreadsNumDims to ParallelOperands apply
method
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 7b12750eebb4f..8d03a04d87a21 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -177,6 +177,7 @@ class HostEvalInfo {
parallelApplied = true;
clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
+ clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
return true;
}
>From 76d39229e04f303dc6b847c434e0904bece00f1a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 12:32:56 +0530
Subject: [PATCH 7/8] Remove dims(N) syntax and use list of vals for
num_threads
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 14 +++--
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 53 +++++++++----------
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 51 +++++-------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 26 ++++-----
mlir/test/Dialect/OpenMP/invalid.mlir | 31 -----------
mlir/test/Dialect/OpenMP/ops.mlir | 11 +++-
mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++++
9 files changed, 76 insertions(+), 126 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 8a96872294124..e33bdcc5c4dbd 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,7 +516,7 @@ bool ClauseProcessor::processNumThreads(
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result.numThreadsDimsValues.push_back(
+ result.numThreadsVals.push_back(
fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
return true;
}
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 8d03a04d87a21..9947dcc8d5ebc 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,7 +99,7 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
vars.push_back(ops.numTeamsUpper);
- for (auto numThreads : ops.numThreadsDimsValues)
+ for (auto numThreads : ops.numThreadsVals)
vars.push_back(numThreads);
if (ops.threadLimit)
@@ -115,8 +115,7 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- (ops.numTeamsUpper ? 1 : 0) +
- ops.numThreadsDimsValues.size() +
+ (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVals.size() +
(ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
@@ -135,8 +134,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
ops.numTeamsUpper = args[argIndex++];
- for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
- ops.numThreadsDimsValues[i] = args[argIndex++];
+ for (size_t i = 0; i < ops.numThreadsVals.size(); ++i)
+ ops.numThreadsVals[i] = args[argIndex++];
if (ops.threadLimit)
ops.threadLimit = args[argIndex++];
@@ -170,14 +169,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::ParallelOperands &clauseOps) {
- if (ops.numThreadsDimsValues.empty() || parallelApplied) {
+ if (ops.numThreadsVals.empty() || parallelApplied) {
parallelApplied = true;
return false;
}
parallelApplied = true;
- clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
- clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
+ clauseOps.numThreadsVals = ops.numThreadsVals;
return true;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 10aaab4b6f21c..cda6906d46965 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,53 +1069,48 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
- Variadic<IntLikeType>:$num_threads_dims_values
+ Variadic<IntLikeType>:$num_threads_vals
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
+ $num_threads_vals, type($num_threads_vals)
) `)`
}];
let description = [{
- num_threads clause specifies the desired number of threads in the team
- space formed by the construct on which it appears.
-
- With dims modifier:
- - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
- - Specifies upper bounds for each dimension (all must have same type)
- - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
-
- Without dims modifier:
- - The number of threads is specified by single value in `num_threads_dims_values`
- - Format: `num_threads(value : type)`
+ The `num_threads` clause specifies the number of threads.
+
+ Multi-dimensional format (dims modifier):
+ - Multiple values can be specified for multi-dimensional thread counts.
+ - The number of dimensions is derived from the number of values.
+ - Values can have different integer types.
+ - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)`
+ - Example: `num_threads(%n, %m : i32, i64)`
+
+ Single value format:
+ - A single value specifies the number of threads.
+ - Format: `num_threads(%value : type)`
- Example: `num_threads(%n : i32)`
}];
let extraClassDeclaration = [{
- /// Returns true if the dims modifier is explicitly present
- bool hasNumThreadsDimsModifier() {
- return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+ /// Returns true if using multi-dimensional values (more than one value)
+ bool hasNumThreadsMultiDim() {
+ return getNumThreadsVals().size() > 1;
}
- /// Returns the number of dimensions specified by dims modifier
+ /// Returns the number of dimensions specified for num_threads
unsigned getNumThreadsDimsCount() {
- if (!hasNumThreadsDimsModifier())
- return 1;
- return static_cast<unsigned>(*getNumThreadsNumDims());
+ return getNumThreadsVals().size();
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumThreadsDimsCount()
- ::mlir::Value getNumThreadsDimsValue(unsigned index) {
- assert(index < getNumThreadsDimsCount() &&
- "Num threads dims index out of bounds");
- if(getNumThreadsDimsValues().empty())
- return nullptr;
- return getNumThreadsDimsValues()[index];
+ /// Index must be less than getNumThreadsVals().size()
+ ::mlir::Value getNumThreadsVal(unsigned index) {
+ assert(index < getNumThreadsVals().size() &&
+ "Num threads index out of bounds");
+ return getNumThreadsVals()[index];
}
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 3a1f311dd63f0..35288687a7eac 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -499,8 +499,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads_num_dims = */ nullptr,
- /* num_threads_dims_values = */ numThreadsValues,
+ /* num_threads_vals = */ numThreadsValues,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 9366f04e51629..65a006b48f480 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,8 +2252,7 @@ LogicalResult TargetOp::verifyRegions() {
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
- llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
- hostEvalArg))
+ llvm::is_contained(parallelOp.getNumThreadsVals(), hostEvalArg))
continue;
return emitOpError()
@@ -2505,8 +2504,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
- /*num_threads_dims=*/nullptr,
- /*num_threads_values=*/ValueRange(),
+ /*num_threads_vals=*/ValueRange(),
/*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2519,8 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsNumDims,
- clauses.numThreadsDimsValues, clauses.privateVars,
+ clauses.ifExpr, clauses.numThreadsVals, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.procBindKind,
clauses.reductionMod, clauses.reductionVars,
@@ -2571,23 +2568,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
-// Helper: Verify num_threads clause
-LogicalResult
-verifyNumThreadsClause(Operation *op,
- std::optional<IntegerAttr> numThreadsNumDims,
- OperandRange numThreadsDimsValues) {
- if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
- return failure();
- return success();
-}
-
LogicalResult ParallelOp::verify() {
- // verify num_threads clause restrictions
- if (failed(verifyNumThreadsClause(getOperation(),
- this->getNumThreadsNumDimsAttr(),
- this->getNumThreadsDimsValues())))
- return failure();
-
// verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
@@ -4650,30 +4631,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
// Parser and printer for num_threads clause
//===----------------------------------------------------------------------===//
static ParseResult
-parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+parseNumThreadsClause(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &types) {
- if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
- return success();
- }
-
- // Without dims modifier: value : type
- OpAsmParser::UnresolvedOperand singleValue;
- Type singleType;
- if (parser.parseOperand(singleValue) || parser.parseColon() ||
- parser.parseType(singleType)) {
+ // Parse comma-separated list of values with their types
+ // Format: %v1, %v2, ... : type1, type2, ...
+ if (parser.parseOperandList(values) || parser.parseColon() ||
+ parser.parseTypeList(types)) {
return failure();
}
- values.push_back(singleValue);
- types.push_back(singleType);
return success();
}
static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
- IntegerAttr dimsAttr, OperandRange values,
- TypeRange types) {
- // Multidimensional: dims(N): values : type
- printDimsModifierWithValues(p, dimsAttr, values, types);
+ OperandRange values, TypeRange types) {
+ // Print values with their types
+ llvm::interleaveComma(values, p, [&](Value v) { p << v; });
+ p << " : ";
+ llvm::interleaveComma(types, p, [&](Type t) { p << t; });
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d4aaa832636d1..a1cb06254f4b0 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
+ auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) {
+ if (op.hasNumThreadsMultiDim())
+ result = todo("num_threads with multi-dimensional values");
+ };
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
@@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::ParallelOp op) {
checkAllocate(op, result);
checkReduction(op, result);
+ checkNumThreadsMultiDim(op, result);
})
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3268,11 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
- // num_threads dims and values are not yet supported
- assert(!opInst.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is not yet implemented.");
- if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
- numThreads = moduleTranslation.lookupValue(numThreadsVar);
+ if (!opInst.getNumThreadsVals().empty())
+ numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0));
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
pbKind = getProcBindKind(*bind);
@@ -6053,11 +6055,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
- // num_threads dims and values are not yet supported
- assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is not yet "
- "implemented.");
- if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
+ if (!parallelOp.getNumThreadsVals().empty() &&
+ parallelOp.getNumThreadsVal(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6175,11 +6174,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
- // num_threads dims and values are not yet supported
- assert(
- !parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is not yet implemented.");
- numThreads = parallelOp.getNumThreadsDimsValue(0);
+ if (!parallelOp.getNumThreadsVals().empty())
+ numThreads = parallelOp.getNumThreadsVal(0);
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index b2e20b4c5ee5a..cd06011c2cbc4 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,37 +30,6 @@ func.func @num_threads_once(%n : si32) {
// -----
-func.func @num_threads_dims_no_values() {
- // expected-error@+1 {{dims modifier requires values to be specified}}
- "omp.parallel"() ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
- return
-}
-
-// -----
-
-func.func @num_threads_dims_mismatch(%n : i64) {
- // expected-error@+1 {{dims(2) specified but 1 values provided}}
- omp.parallel num_threads(dims(2): %n : i64) {
- omp.terminator
- }
-
- return
-}
-
-// -----
-
-func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
- // expected-error@+1 {{dims values can only be specified with dims modifier}}
- "omp.parallel"(%n, %m) ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
- return
-}
-
-// -----
-
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error at +1 {{expected '{' to begin a region}}
omp.parallel nowait {}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index e2a3f8fbe2d5f..1700ad696f86f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -160,8 +160,15 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
- // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
- omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64)
+ omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) {
+ omp.terminator
+ }
+
+ %n_i16 = arith.constant 8 : i16
+ // Test num_threads with mixed types.
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+ omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) {
omp.terminator
}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 1ea56fdd0bf16..e4c47aae9b485 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -443,6 +443,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
+llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
+ // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+ // expected-error@below {{LLVM Translation failed for operation: omp.parallel}}
+ omp.parallel num_threads(%lb, %ub : i32, i32) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
>From e1fc5f168954e48730538dd83a58a2603ce2ab3b Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sat, 17 Jan 2026 10:37:09 +0530
Subject: [PATCH 8/8] remove custom parser printer for num_threads
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 82 +------------------
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 29 -------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 +--
mlir/test/Dialect/OpenMP/invalid.mlir | 19 ++++-
5 files changed, 28 insertions(+), 115 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index cda6906d46965..228e6e2deb1fb 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1073,9 +1073,7 @@ class OpenMP_NumThreadsClauseSkip<
);
let optAssemblyFormat = [{
- `num_threads` `(` custom<NumThreadsClause>(
- $num_threads_vals, type($num_threads_vals)
- ) `)`
+ `num_threads` `(` $num_threads_vals `:` type($num_threads_vals) `)`
}];
let description = [{
@@ -1107,10 +1105,10 @@ class OpenMP_NumThreadsClauseSkip<
/// Returns the value for a specific dimension index
/// Index must be less than getNumThreadsVals().size()
- ::mlir::Value getNumThreadsVal(unsigned index) {
- assert(index < getNumThreadsVals().size() &&
+ ::mlir::Value getNumThreads(unsigned dim = 0) {
+ assert(dim < getNumThreadsDimsCount() &&
"Num threads index out of bounds");
- return getNumThreadsVals()[index];
+ return getNumThreadsVals()[dim];
}
}];
}
@@ -1600,76 +1598,4 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
-//===----------------------------------------------------------------------===//
-// V6.2: Multidimensional `num_teams` clause with dims modifier
-//===----------------------------------------------------------------------===//
-
-class OpenMP_NumTeamsMultiDimClauseSkip<
- bit traits = false, bit arguments = false, bit assemblyFormat = false,
- bit description = false, bit extraClassDeclaration = false
- > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
- extraClassDeclaration> {
- let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
- Variadic<AnyInteger>:$num_teams_values
- );
-
- let optAssemblyFormat = [{
- `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
- $num_teams_values,
- type($num_teams_values)) `)`
- }];
-
- let description = [{
- The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
- the number of teams to be created in a multidimensional team space.
-
- The dims modifier for the num_teams_multi_dim clause specifies the number of
- dimensions for the league space (team space) that the clause arranges.
- The dimensions argument in the dims modifier specifies the number of
- dimensions and determines the length of the list argument. The list items
- are specified in ascending order according to the ordinal number of the
- dimensions (dimension 0, 1, 2, ..., N-1).
-
- - If `dims` is not specified: The space is unidimensional (1D) with a single value
- - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
- - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
-
- **Examples:**
- - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
- 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
- - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
- }];
-
- let extraClassDeclaration = [{
- /// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumTeamsDims().has_value();
- }
-
- /// Returns the number of dimensions specified by dims modifier
- /// Returns 1 if dims modifier is not present (unidimensional by default)
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
- return 1;
- return static_cast<unsigned>(*getNumTeamsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumTeamsValues();
- }
-
- /// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
- }
- }];
-}
-
-def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
-
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 76eeb0bd70ec3..d4e8cecda2601 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,8 +241,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
- OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 65a006b48f480..4cdeaa0bc8e87 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2629,13 +2629,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-<<<<<<< HEAD
clauses.ifExpr, clauses.numTeamsVals, clauses.numTeamsLower,
clauses.numTeamsUpper,
-=======
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
- clauses.numTeamsDims, clauses.numTeamsValues,
->>>>>>> [OpenMP][MLIR] Add num_teams clause with dims modifier support
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
@@ -4627,30 +4622,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
-//===----------------------------------------------------------------------===//
-// Parser and printer for num_threads clause
-//===----------------------------------------------------------------------===//
-static ParseResult
-parseNumThreadsClause(OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types) {
- // Parse comma-separated list of values with their types
- // Format: %v1, %v2, ... : type1, type2, ...
- if (parser.parseOperandList(values) || parser.parseColon() ||
- parser.parseTypeList(types)) {
- return failure();
- }
- return success();
-}
-
-static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
- OperandRange values, TypeRange types) {
- // Print values with their types
- llvm::interleaveComma(values, p, [&](Value v) { p << v; });
- p << " : ";
- llvm::interleaveComma(types, p, [&](Type t) { p << t; });
-}
-
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a1cb06254f4b0..73a91b3707c57 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
- auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) {
+ auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
if (op.hasNumThreadsMultiDim())
result = todo("num_threads with multi-dimensional values");
};
@@ -435,7 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::ParallelOp op) {
checkAllocate(op, result);
checkReduction(op, result);
- checkNumThreadsMultiDim(op, result);
+ checkNumThreads(op, result);
})
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3274,7 +3274,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
if (!opInst.getNumThreadsVals().empty())
- numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0));
+ numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
pbKind = getProcBindKind(*bind);
@@ -6056,7 +6056,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
})
.Case([&](omp::ParallelOp parallelOp) {
if (!parallelOp.getNumThreadsVals().empty() &&
- parallelOp.getNumThreadsVal(0) == blockArg)
+ parallelOp.getNumThreads(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6175,7 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
if (!parallelOp.getNumThreadsVals().empty())
- numThreads = parallelOp.getNumThreadsVal(0);
+ numThreads = parallelOp.getNumThreads(0);
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index cd06011c2cbc4..d451b14e8bfc9 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1451,7 +1451,24 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_multidim_with_bounds() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ %lb = arith.constant 3 : i32
+ %ub = arith.constant 4 : i32
+ // expected-error @below {{num_teams multi-dimensional values cannot be used together with legacy lower/upper bounds}}
+ "omp.teams" (%v0, %v1, %lb, %ub) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
omp.terminator
}
return
More information about the llvm-branch-commits
mailing list