[llvm-branch-commits] [flang] [mlir] [OpenMP][MLIR] Add num_threads clause with dims modifier support (PR #171767)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 16 21:08:09 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 6093bdcf18e36ad0ef1b97c6c2cac8b8cd9000c3 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 11:56:58 +0530
Subject: [PATCH 1/7] [OpenMP][MLIR] Add num_threads clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++--
mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 15 ++--
5 files changed, 163 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d4640f254ed1f..aedfa05da1608 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
+ Variadic<AnyInteger>:$num_threads_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - If lower bound not specified, it defaults to upper bound value
+ - Format: `num_threads(bounds : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumThreadsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumThreadsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..0d5333ec2e455 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_dims = */ nullptr,
+ /* num_threads_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 67ff9023a38da..9664b8f59802c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2504,6 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2515,13 +2517,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2568,13 +2571,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ auto numThreadsDims = getNumThreadsDims();
+ auto numThreadsValues = getNumThreadsValues();
+ auto numThreads = getNumThreads();
+
+ // num_threads with dims modifier
+ if (numThreadsDims.has_value() && numThreadsValues.empty()) {
+ return emitError(
+ "num_threads dims modifier requires values to be specified");
+ }
+
+ if (numThreadsDims.has_value() &&
+ numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
+ return emitError("num_threads dims(")
+ << *numThreadsDims << ") specified but " << numThreadsValues.size()
+ << " values provided";
+ }
+
+ // num_threads dims and number of threads cannot be used together
+ if (numThreadsDims.has_value() && numThreads) {
+ return emitError(
+ "num_threads dims and number of threads cannot be used together");
+ }
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4595,6 +4625,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index bb882db73cbab..75431ec475954 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error at +1 {{num_threads dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error at +1 {{num_threads dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error at +1 {{num_threads dims and number of threads cannot be used together}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error at +1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2708,7 +2739,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 89c7e5fd48bd9..3acbe010c28a5 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
>From 97045e6201626b5f73e5178905a9a2cefa09b9cf Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 12:11:49 +0530
Subject: [PATCH 2/7] Mark mlir->llvmir translation for num_threads with dims
as NYI
---
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8a3a990e5a3fd..e66666b526069 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3268,6 +3268,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.getNumThreadsDims().has_value() &&
+ opInst.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -6050,6 +6054,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -6167,8 +6175,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
>From 60288588459e658d9d2d1238569a19f34e932b80 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 17:37:52 +0530
Subject: [PATCH 3/7] Rename num_threads dims operands/helpers and factor out num_threads verifier
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++--------
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++--
mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++---
5 files changed, 45 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index aedfa05da1608..3559002c6473f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
- Variadic<AnyInteger>:$num_threads_values,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
$num_threads, type($num_threads)
) `)`
}];
@@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip<
space formed by the construct on which it appears.
With dims modifier:
- - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
- Specifies upper bounds for each dimension (all must have same type)
- Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip<
let extraClassDeclaration = [{
/// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumThreadsDims().has_value();
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
}
/// Returns the number of dimensions specified by dims modifier
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
return 1;
- return static_cast<unsigned>(*getNumThreadsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumThreadsValues();
+ return static_cast<unsigned>(*getNumThreadsNumDims());
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
}
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 0d5333ec2e455..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,8 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads_dims = */ nullptr,
- /* num_threads_values = */ llvm::SmallVector<Value>{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 9664b8f59802c..54ce42f684581 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2519,7 +2519,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
ParallelOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
@@ -2570,30 +2570,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
-LogicalResult ParallelOp::verify() {
- // verify num_threads clause restrictions
- auto numThreadsDims = getNumThreadsDims();
- auto numThreadsValues = getNumThreadsValues();
- auto numThreads = getNumThreads();
-
- // num_threads with dims modifier
- if (numThreadsDims.has_value() && numThreadsValues.empty()) {
- return emitError(
- "num_threads dims modifier requires values to be specified");
- }
-
- if (numThreadsDims.has_value() &&
- numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
- return emitError("num_threads dims(")
- << *numThreadsDims << ") specified but " << numThreadsValues.size()
- << " values provided";
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
}
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
- // num_threads dims and number of threads cannot be used together
- if (numThreadsDims.has_value() && numThreads) {
- return emitError(
- "num_threads dims and number of threads cannot be used together");
- }
+LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
// verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e66666b526069..67f30383bb03a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3269,8 +3269,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
- assert(!opInst.getNumThreadsDims().has_value() &&
- opInst.getNumThreadsValues().empty() &&
+ assert(!opInst.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
@@ -6055,8 +6054,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
})
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
@@ -6177,8 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 75431ec475954..1c5ef785a17f9 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) {
// -----
func.func @num_threads_dims_no_values() {
- // expected-error at +1 {{num_threads dims modifier requires values to be specified}}
+ // expected-error at +1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
// -----
func.func @num_threads_dims_mismatch(%n : i64) {
- // expected-error at +1 {{num_threads dims(2) specified but 1 values provided}}
+ // expected-error at +1 {{dims(2) specified but 1 values provided}}
omp.parallel num_threads(dims(2): %n : i64) {
omp.terminator
}
@@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error at +1 {{num_threads dims and number of threads cannot be used together}}
+ // expected-error at +1 {{num_threads with dims modifier cannot be used together with number of threads}}
"omp.parallel"(%n, %n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
return
}
>From f07a41aa54d4f27ced44bba8b013e12b4f5ba1dd Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 12:27:38 +0530
Subject: [PATCH 4/7] Drop scalar num_threads operand; use num_threads_dims_values only
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 15 ++---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 15 +++--
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 5 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 62 ++++++++-----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++---
mlir/test/Dialect/OpenMP/invalid.mlir | 12 ++--
mlir/test/Dialect/OpenMP/ops.mlir | 10 +--
8 files changed, 66 insertions(+), 73 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..abaeaa90f80be 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,8 +516,8 @@ bool ClauseProcessor::processNumThreads(
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result.numThreads =
- fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsDimsValues.push_back(
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 989e370870f33..bdbabc292349a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,8 +99,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
vars.push_back(ops.numTeamsUpper);
- if (ops.numThreads)
- vars.push_back(ops.numThreads);
+ for (auto numThreads : ops.numThreadsDimsValues)
+ vars.push_back(numThreads);
if (ops.threadLimit)
vars.push_back(ops.threadLimit);
@@ -115,7 +115,8 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+ (ops.numTeamsUpper ? 1 : 0) +
+ ops.numThreadsDimsValues.size() +
(ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
@@ -134,8 +135,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
ops.numTeamsUpper = args[argIndex++];
- if (ops.numThreads)
- ops.numThreads = args[argIndex++];
+ for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
+ ops.numThreadsDimsValues[i] = args[argIndex++];
if (ops.threadLimit)
ops.threadLimit = args[argIndex++];
@@ -169,13 +170,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::ParallelOperands &clauseOps) {
- if (!ops.numThreads || parallelApplied) {
+ if (ops.numThreadsDimsValues.empty() || parallelApplied) {
parallelApplied = true;
return false;
}
parallelApplied = true;
- clauseOps.numThreads = ops.numThreads;
+ clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
return true;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 3559002c6473f..8be7030599cc6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip<
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
- Variadic<AnyInteger>:$num_threads_dims_values,
- Optional<IntLikeType>:$num_threads
+ Variadic<IntLikeType>:$num_threads_dims_values
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
- $num_threads, type($num_threads)
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
) `)`
}];
@@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip<
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
Without dims modifier:
- - Uses `num_threads`
- - If lower bound not specified, it defaults to upper bound value
- - Format: `num_threads(bounds : type)`
- - Example: `num_threads(%ub : i32)`
+ - The number of threads is specified by single value in `num_threads_dims_values`
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%n : i32)`
}];
let extraClassDeclaration = [{
@@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip<
::mlir::Value getNumThreadsDimsValue(unsigned index) {
assert(index < getNumThreadsDimsCount() &&
"Num threads dims index out of bounds");
+ if(getNumThreadsDimsValues().empty())
+ return nullptr;
return getNumThreadsDimsValues()[index];
}
}];
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ab7bded7835be..5d75613f9b2b6 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -438,9 +438,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.eraseOp(reduce);
Value numThreadsVar;
+ SmallVector<Value> numThreadsValues;
if (numThreads > 0) {
numThreadsVar = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+ numThreadsValues.push_back(numThreadsVar);
}
// Create the parallel wrapper.
auto ompParallel = omp::ParallelOp::create(
@@ -449,8 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
/* num_threads_num_dims = */ nullptr,
- /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
- /* num_threads = */ numThreadsVar,
+ /* num_threads_dims_values = */ numThreadsValues,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 54ce42f684581..6911272d43f6e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,7 +2252,8 @@ LogicalResult TargetOp::verifyRegions() {
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
- hostEvalArg == parallelOp.getNumThreads())
+ llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
+ hostEvalArg))
continue;
return emitOpError()
@@ -2506,7 +2507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads_dims=*/nullptr,
/*num_threads_values=*/ValueRange(),
- /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+ /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
/*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
@@ -2517,14 +2518,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
- clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
- clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims,
+ clauses.numThreadsDimsValues, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms),
+ clauses.privateNeedsBarrier, clauses.procBindKind,
+ clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2574,13 +2575,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
LogicalResult
verifyNumThreadsClause(Operation *op,
std::optional<IntegerAttr> numThreadsNumDims,
- OperandRange numThreadsDimsValues, Value numThreads) {
- bool hasDimsModifier =
- numThreadsNumDims.has_value() && numThreadsNumDims.value();
- if (hasDimsModifier && numThreads) {
- return op->emitError("num_threads with dims modifier cannot be used "
- "together with number of threads");
- }
+ OperandRange numThreadsDimsValues) {
if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
return failure();
return success();
@@ -2588,9 +2583,9 @@ verifyNumThreadsClause(Operation *op,
LogicalResult ParallelOp::verify() {
// verify num_threads clause restrictions
- if (failed(verifyNumThreadsClause(
- getOperation(), this->getNumThreadsNumDimsAttr(),
- this->getNumThreadsDimsValues(), this->getNumThreads())))
+ if (failed(verifyNumThreadsClause(getOperation(),
+ this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues())))
return failure();
// verify allocate clause restrictions
@@ -4629,33 +4624,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
static ParseResult
parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types,
- std::optional<OpAsmParser::UnresolvedOperand> &bounds,
- Type &boundsType) {
+ SmallVectorImpl<Type> &types) {
if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
return success();
}
- OpAsmParser::UnresolvedOperand boundsOperand;
- if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
- parser.parseType(boundsType)) {
+ // Without dims modifier: value : type
+ OpAsmParser::UnresolvedOperand singleValue;
+ Type singleType;
+ if (parser.parseOperand(singleValue) || parser.parseColon() ||
+ parser.parseType(singleType)) {
return failure();
}
- bounds = boundsOperand;
+ values.push_back(singleValue);
+ types.push_back(singleType);
return success();
}
static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
IntegerAttr dimsAttr, OperandRange values,
- TypeRange types, Value bounds,
- Type boundsType) {
- if (!values.empty()) {
- printDimsModifierWithValues(p, dimsAttr, values, types);
- }
- if (bounds) {
- p.printOperand(bounds);
- p << " : " << boundsType;
- }
+ TypeRange types) {
+ // Multidimensional: dims(N): values : type
+ printDimsModifierWithValues(p, dimsAttr, values, types);
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 67f30383bb03a..da44dda0a1230 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3270,8 +3270,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
assert(!opInst.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- if (auto numThreadsVar = opInst.getNumThreads())
+ "Lowering of num_threads with dims modifier is not yet implemented.");
+ if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
@@ -6055,8 +6055,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- if (parallelOp.getNumThreads() == blockArg)
+ "Lowering of num_threads with dims modifier is not yet "
+ "implemented.");
+ if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6175,9 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- numThreads = parallelOp.getNumThreads();
+ assert(
+ !parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is not yet implemented.");
+ numThreads = parallelOp.getNumThreadsDimsValue(0);
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1c5ef785a17f9..8a5e64b1a98ca 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() {
// expected-error at +1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
@@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
-func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error at +1 {{num_threads with dims modifier cannot be used together with number of threads}}
- "omp.parallel"(%n, %n, %m) ({
+func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
+ // expected-error at +1 {{dims values can only be specified with dims modifier}}
+ "omp.parallel"(%n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
return
}
@@ -2739,7 +2739,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3acbe010c28a5..4c57b8aea0b48 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
>From 038f9f4b3cfd4664f4df95e141178c6289194ac4 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 Jan 2026 12:07:56 +0530
Subject: [PATCH 5/7] Fix missing copy of numThreadsNumDims in
 HostEvalInfo::apply(ParallelOperands &)
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index bdbabc292349a..5ca228e218c37 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -177,6 +177,7 @@ class HostEvalInfo {
parallelApplied = true;
clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
+ clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
return true;
}
>From 12c4749a7dc638ea4f22f2e1dd9cf9fd987f5123 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 12:32:56 +0530
Subject: [PATCH 6/7] Remove dims(N) syntax and use list of vals for
num_threads
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 14 +++--
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 53 +++++++++----------
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 51 +++++-------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 26 ++++-----
mlir/test/Dialect/OpenMP/invalid.mlir | 31 -----------
mlir/test/Dialect/OpenMP/ops.mlir | 11 +++-
mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++++
9 files changed, 76 insertions(+), 126 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index abaeaa90f80be..90825a3653016 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,7 +516,7 @@ bool ClauseProcessor::processNumThreads(
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result.numThreadsDimsValues.push_back(
+ result.numThreadsVals.push_back(
fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
return true;
}
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5ca228e218c37..c9271925580cd 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,7 +99,7 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
vars.push_back(ops.numTeamsUpper);
- for (auto numThreads : ops.numThreadsDimsValues)
+ for (auto numThreads : ops.numThreadsVals)
vars.push_back(numThreads);
if (ops.threadLimit)
@@ -115,8 +115,7 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- (ops.numTeamsUpper ? 1 : 0) +
- ops.numThreadsDimsValues.size() +
+ (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVals.size() +
(ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
@@ -135,8 +134,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
ops.numTeamsUpper = args[argIndex++];
- for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
- ops.numThreadsDimsValues[i] = args[argIndex++];
+ for (size_t i = 0; i < ops.numThreadsVals.size(); ++i)
+ ops.numThreadsVals[i] = args[argIndex++];
if (ops.threadLimit)
ops.threadLimit = args[argIndex++];
@@ -170,14 +169,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::ParallelOperands &clauseOps) {
- if (ops.numThreadsDimsValues.empty() || parallelApplied) {
+ if (ops.numThreadsVals.empty() || parallelApplied) {
parallelApplied = true;
return false;
}
parallelApplied = true;
- clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
- clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
+ clauseOps.numThreadsVals = ops.numThreadsVals;
return true;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8be7030599cc6..90bff92fbc826 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,53 +1069,48 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
- Variadic<IntLikeType>:$num_threads_dims_values
+ Variadic<IntLikeType>:$num_threads_vals
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
+ $num_threads_vals, type($num_threads_vals)
) `)`
}];
let description = [{
- num_threads clause specifies the desired number of threads in the team
- space formed by the construct on which it appears.
-
- With dims modifier:
- - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
- - Specifies upper bounds for each dimension (all must have same type)
- - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
-
- Without dims modifier:
- - The number of threads is specified by single value in `num_threads_dims_values`
- - Format: `num_threads(value : type)`
+ The `num_threads` clause specifies the number of threads requested for the team formed by the construct on which it appears.
+
+ Multi-dimensional format (corresponds to the OpenMP `dims` modifier):
+ - Multiple values can be specified for multi-dimensional thread counts.
+ - The number of dimensions is derived from the number of values.
+ - Values can have different integer types.
+ - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)`
+ - Example: `num_threads(%n, %m : i32, i64)`
+
+ Single value format:
+ - A single value specifies the number of threads.
+ - Format: `num_threads(%value : type)`
- Example: `num_threads(%n : i32)`
}];
let extraClassDeclaration = [{
- /// Returns true if the dims modifier is explicitly present
- bool hasNumThreadsDimsModifier() {
- return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+ /// Returns true if using multi-dimensional values (more than one value)
+ bool hasNumThreadsMultiDim() {
+ return getNumThreadsVals().size() > 1;
}
- /// Returns the number of dimensions specified by dims modifier
+ /// Returns the number of dimensions specified for num_threads
unsigned getNumThreadsDimsCount() {
- if (!hasNumThreadsDimsModifier())
- return 1;
- return static_cast<unsigned>(*getNumThreadsNumDims());
+ return getNumThreadsVals().size();
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumThreadsDimsCount()
- ::mlir::Value getNumThreadsDimsValue(unsigned index) {
- assert(index < getNumThreadsDimsCount() &&
- "Num threads dims index out of bounds");
- if(getNumThreadsDimsValues().empty())
- return nullptr;
- return getNumThreadsDimsValues()[index];
+ /// Index must be less than getNumThreadsVals().size()
+ ::mlir::Value getNumThreadsVal(unsigned index) {
+ assert(index < getNumThreadsVals().size() &&
+ "Num threads index out of bounds");
+ return getNumThreadsVals()[index];
}
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 5d75613f9b2b6..6ba2155c7840f 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,8 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads_num_dims = */ nullptr,
- /* num_threads_dims_values = */ numThreadsValues,
+ /* num_threads_vals = */ numThreadsValues,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6911272d43f6e..bc7647d129f60 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,8 +2252,7 @@ LogicalResult TargetOp::verifyRegions() {
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
- llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
- hostEvalArg))
+ llvm::is_contained(parallelOp.getNumThreadsVals(), hostEvalArg))
continue;
return emitOpError()
@@ -2505,8 +2504,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
- /*num_threads_dims=*/nullptr,
- /*num_threads_values=*/ValueRange(),
+ /*num_threads_vals=*/ValueRange(),
/*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2519,8 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsNumDims,
- clauses.numThreadsDimsValues, clauses.privateVars,
+ clauses.ifExpr, clauses.numThreadsVals, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.procBindKind,
clauses.reductionMod, clauses.reductionVars,
@@ -2571,23 +2568,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
-// Helper: Verify num_threads clause
-LogicalResult
-verifyNumThreadsClause(Operation *op,
- std::optional<IntegerAttr> numThreadsNumDims,
- OperandRange numThreadsDimsValues) {
- if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
- return failure();
- return success();
-}
-
LogicalResult ParallelOp::verify() {
- // verify num_threads clause restrictions
- if (failed(verifyNumThreadsClause(getOperation(),
- this->getNumThreadsNumDimsAttr(),
- this->getNumThreadsDimsValues())))
- return failure();
-
// verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
@@ -4622,30 +4603,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
// Parser and printer for num_threads clause
//===----------------------------------------------------------------------===//
static ParseResult
-parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+parseNumThreadsClause(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &types) {
- if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
- return success();
- }
-
- // Without dims modifier: value : type
- OpAsmParser::UnresolvedOperand singleValue;
- Type singleType;
- if (parser.parseOperand(singleValue) || parser.parseColon() ||
- parser.parseType(singleType)) {
+ // Parse comma-separated list of values with their types
+ // Format: %v1, %v2, ... : type1, type2, ...
+ if (parser.parseOperandList(values) || parser.parseColon() ||
+ parser.parseTypeList(types)) {
return failure();
}
- values.push_back(singleValue);
- types.push_back(singleType);
return success();
}
static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
- IntegerAttr dimsAttr, OperandRange values,
- TypeRange types) {
- // Multidimensional: dims(N): values : type
- printDimsModifierWithValues(p, dimsAttr, values, types);
+ OperandRange values, TypeRange types) {
+ // Print values with their types
+ llvm::interleaveComma(values, p, [&](Value v) { p << v; });
+ p << " : ";
+ llvm::interleaveComma(types, p, [&](Type t) { p << t; });
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da44dda0a1230..2fd3da1b5b30a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
+ auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) {
+ if (op.hasNumThreadsMultiDim())
+ result = todo("num_threads with multi-dimensional values");
+ };
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
@@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::ParallelOp op) {
checkAllocate(op, result);
checkReduction(op, result);
+ checkNumThreadsMultiDim(op, result);
})
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3268,11 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
- // num_threads dims and values are not yet supported
- assert(!opInst.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is not yet implemented.");
- if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
- numThreads = moduleTranslation.lookupValue(numThreadsVar);
+ if (!opInst.getNumThreadsVals().empty())
+ numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0));
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
pbKind = getProcBindKind(*bind);
@@ -6053,11 +6055,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
- // num_threads dims and values are not yet supported
- assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is not yet "
- "implemented.");
- if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
+ if (!parallelOp.getNumThreadsVals().empty() &&
+ parallelOp.getNumThreadsVal(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6175,11 +6174,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
- // num_threads dims and values are not yet supported
- assert(
- !parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is not yet implemented.");
- numThreads = parallelOp.getNumThreadsDimsValue(0);
+ if (!parallelOp.getNumThreadsVals().empty())
+ numThreads = parallelOp.getNumThreadsVal(0);
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8a5e64b1a98ca..bb882db73cbab 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,37 +30,6 @@ func.func @num_threads_once(%n : si32) {
// -----
-func.func @num_threads_dims_no_values() {
- // expected-error at +1 {{dims modifier requires values to be specified}}
- "omp.parallel"() ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
- return
-}
-
-// -----
-
-func.func @num_threads_dims_mismatch(%n : i64) {
- // expected-error at +1 {{dims(2) specified but 1 values provided}}
- omp.parallel num_threads(dims(2): %n : i64) {
- omp.terminator
- }
-
- return
-}
-
-// -----
-
-func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
- // expected-error at +1 {{dims values can only be specified with dims modifier}}
- "omp.parallel"(%n, %m) ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
- return
-}
-
-// -----
-
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error at +1 {{expected '{' to begin a region}}
omp.parallel nowait {}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4c57b8aea0b48..67f93869d4be7 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -160,8 +160,15 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
- // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
- omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64)
+ omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) {
+ omp.terminator
+ }
+
+ %n_i16 = arith.constant 8 : i16
+ // Test num_threads with mixed types.
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+ omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) {
omp.terminator
}
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 3681ce38bd523..fd218e91d0b46 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -452,6 +452,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
+llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
+ // expected-error at below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+ // expected-error at below {{LLVM Translation failed for operation: omp.parallel}}
+ omp.parallel num_threads(%lb, %ub : i32, i32) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// expected-error at below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
// expected-error at below {{LLVM Translation failed for operation: omp.wsloop}}
>From 1033cc66ab5617df178499b6138a6f00a7da18f5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sat, 17 Jan 2026 10:37:09 +0530
Subject: [PATCH 7/7] Remove custom parser/printer for num_threads
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 10 ++++----
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 24 -------------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 10 ++++----
3 files changed, 9 insertions(+), 35 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 90bff92fbc826..7d0e1e3f91af4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1073,9 +1073,7 @@ class OpenMP_NumThreadsClauseSkip<
);
let optAssemblyFormat = [{
- `num_threads` `(` custom<NumThreadsClause>(
- $num_threads_vals, type($num_threads_vals)
- ) `)`
+ `num_threads` `(` $num_threads_vals `:` type($num_threads_vals) `)`
}];
let description = [{
@@ -1107,10 +1105,10 @@ class OpenMP_NumThreadsClauseSkip<
/// Returns the value for a specific dimension index
/// Index must be less than getNumThreadsVals().size()
- ::mlir::Value getNumThreadsVal(unsigned index) {
- assert(index < getNumThreadsVals().size() &&
+ ::mlir::Value getNumThreads(unsigned dim = 0) {
+ assert(dim < getNumThreadsDimsCount() &&
"Num threads index out of bounds");
- return getNumThreadsVals()[index];
+ return getNumThreadsVals()[dim];
}
}];
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bc7647d129f60..ab1038c755f7a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4599,30 +4599,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
-//===----------------------------------------------------------------------===//
-// Parser and printer for num_threads clause
-//===----------------------------------------------------------------------===//
-static ParseResult
-parseNumThreadsClause(OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types) {
- // Parse comma-separated list of values with their types
- // Format: %v1, %v2, ... : type1, type2, ...
- if (parser.parseOperandList(values) || parser.parseColon() ||
- parser.parseTypeList(types)) {
- return failure();
- }
- return success();
-}
-
-static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
- OperandRange values, TypeRange types) {
- // Print values with their types
- llvm::interleaveComma(values, p, [&](Value v) { p << v; });
- p << " : ";
- llvm::interleaveComma(types, p, [&](Type t) { p << t; });
-}
-
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2fd3da1b5b30a..b92ec9332d43a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
- auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) {
+ auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
if (op.hasNumThreadsMultiDim())
result = todo("num_threads with multi-dimensional values");
};
@@ -435,7 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
.Case([&](omp::ParallelOp op) {
checkAllocate(op, result);
checkReduction(op, result);
- checkNumThreadsMultiDim(op, result);
+ checkNumThreads(op, result);
})
.Case([&](omp::SimdOp op) { checkReduction(op, result); })
.Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3274,7 +3274,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
if (!opInst.getNumThreadsVals().empty())
- numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0));
+ numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
pbKind = getProcBindKind(*bind);
@@ -6056,7 +6056,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
})
.Case([&](omp::ParallelOp parallelOp) {
if (!parallelOp.getNumThreadsVals().empty() &&
- parallelOp.getNumThreadsVal(0) == blockArg)
+ parallelOp.getNumThreads(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6175,7 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
if (!parallelOp.getNumThreadsVals().empty())
- numThreads = parallelOp.getNumThreadsVal(0);
+ numThreads = parallelOp.getNumThreads(0);
}
}
More information about the llvm-branch-commits
mailing list