[llvm-branch-commits] [flang] [mlir] [OpenMP][MLIR] Add num_threads clause with dims modifier support (PR #171767)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 13 22:47:13 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 44a528d177fcba049bc6a6a87962addebac8a443 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 11:56:58 +0530
Subject: [PATCH 1/5] [OpenMP][MLIR] Add num_threads clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++--
mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 15 ++--
5 files changed, 163 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b949e2629a095..56cfd016ef52b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
+ Variadic<AnyInteger>:$num_threads_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ num_threads clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - If lower bound not specified, it defaults to upper bound value
+ - Format: `num_threads(bounds : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumThreadsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumThreadsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..0d5333ec2e455 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_dims = */ nullptr,
+ /* num_threads_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 03897c4c97df8..395cc36887b02 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2538,6 +2538,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2549,13 +2551,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2602,13 +2605,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ auto numThreadsDims = getNumThreadsDims();
+ auto numThreadsValues = getNumThreadsValues();
+ auto numThreads = getNumThreads();
+
+ // num_threads with dims modifier
+ if (numThreadsDims.has_value() && numThreadsValues.empty()) {
+ return emitError(
+ "num_threads dims modifier requires values to be specified");
+ }
+
+ if (numThreadsDims.has_value() &&
+ numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
+ return emitError("num_threads dims(")
+ << *numThreadsDims << ") specified but " << numThreadsValues.size()
+ << " values provided";
+ }
+
+ // num_threads dims and number of threads cannot be used together
+ if (numThreadsDims.has_value() && numThreads) {
+ return emitError(
+ "num_threads dims and number of threads cannot be used together");
+ }
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4659,6 +4689,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..9e2e5722aab9f 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 24c2383f3c3aa..91cf899349c01 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
>From 2a53df3a92fc6604cffc9f80bde71065dca1e0aa Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 12:11:49 +0530
Subject: [PATCH 2/5] Mark mlir->llvmir translation for num_threads with dims
as NYI
---
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 965a399fd653f..5072da9e0b9c7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3268,6 +3268,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.getNumThreadsDims().has_value() &&
+ opInst.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -6054,6 +6058,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -6175,8 +6183,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
>From dcd8d2ad0f436dee3005c73e2e208e8590b30697 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 17:37:52 +0530
Subject: [PATCH 3/5] few more fixes
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++--------
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++--
mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++---
5 files changed, 45 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 56cfd016ef52b..b5c51c5a8dff3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
- Variadic<AnyInteger>:$num_threads_values,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
$num_threads, type($num_threads)
) `)`
}];
@@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip<
space formed by the construct on which it appears.
With dims modifier:
- - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
- Specifies upper bounds for each dimension (all must have same type)
- Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip<
let extraClassDeclaration = [{
/// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumThreadsDims().has_value();
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
}
/// Returns the number of dimensions specified by dims modifier
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
return 1;
- return static_cast<unsigned>(*getNumThreadsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumThreadsValues();
+ return static_cast<unsigned>(*getNumThreadsNumDims());
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
}
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 0d5333ec2e455..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,8 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads_dims = */ nullptr,
- /* num_threads_values = */ llvm::SmallVector<Value>{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 395cc36887b02..4aa97a76c7cb4 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2553,7 +2553,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
ParallelOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
@@ -2604,30 +2604,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
-LogicalResult ParallelOp::verify() {
- // verify num_threads clause restrictions
- auto numThreadsDims = getNumThreadsDims();
- auto numThreadsValues = getNumThreadsValues();
- auto numThreads = getNumThreads();
-
- // num_threads with dims modifier
- if (numThreadsDims.has_value() && numThreadsValues.empty()) {
- return emitError(
- "num_threads dims modifier requires values to be specified");
- }
-
- if (numThreadsDims.has_value() &&
- numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
- return emitError("num_threads dims(")
- << *numThreadsDims << ") specified but " << numThreadsValues.size()
- << " values provided";
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
}
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
- // num_threads dims and number of threads cannot be used together
- if (numThreadsDims.has_value() && numThreads) {
- return emitError(
- "num_threads dims and number of threads cannot be used together");
- }
+LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
// verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 5072da9e0b9c7..d1587dd4ffd41 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3269,8 +3269,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
- assert(!opInst.getNumThreadsDims().has_value() &&
- opInst.getNumThreadsValues().empty() &&
+ assert(!opInst.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
@@ -6059,8 +6058,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
})
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
@@ -6185,8 +6183,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 9e2e5722aab9f..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) {
// -----
func.func @num_threads_dims_no_values() {
- // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+ // expected-error@+1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
// -----
func.func @num_threads_dims_mismatch(%n : i64) {
- // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
omp.parallel num_threads(dims(2): %n : i64) {
omp.terminator
}
@@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
"omp.parallel"(%n, %n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
return
}
>From 84e97fec541d935ec473cf72562861cf4953ceb0 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 12:27:38 +0530
Subject: [PATCH 4/5] Use num_threads_dims_values only
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 4 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 15 ++---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 15 +++--
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 5 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 62 ++++++++-----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++---
mlir/test/Dialect/OpenMP/invalid.mlir | 12 ++--
mlir/test/Dialect/OpenMP/ops.mlir | 10 +--
8 files changed, 66 insertions(+), 73 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..abaeaa90f80be 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,8 +516,8 @@ bool ClauseProcessor::processNumThreads(
mlir::omp::NumThreadsClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
// OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
- result.numThreads =
- fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.numThreadsDimsValues.push_back(
+ fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 989e370870f33..bdbabc292349a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,8 +99,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
vars.push_back(ops.numTeamsUpper);
- if (ops.numThreads)
- vars.push_back(ops.numThreads);
+ for (auto numThreads : ops.numThreadsDimsValues)
+ vars.push_back(numThreads);
if (ops.threadLimit)
vars.push_back(ops.threadLimit);
@@ -115,7 +115,8 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+ (ops.numTeamsUpper ? 1 : 0) +
+ ops.numThreadsDimsValues.size() +
(ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
@@ -134,8 +135,8 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
ops.numTeamsUpper = args[argIndex++];
- if (ops.numThreads)
- ops.numThreads = args[argIndex++];
+ for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
+ ops.numThreadsDimsValues[i] = args[argIndex++];
if (ops.threadLimit)
ops.threadLimit = args[argIndex++];
@@ -169,13 +170,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::ParallelOperands &clauseOps) {
- if (!ops.numThreads || parallelApplied) {
+ if (ops.numThreadsDimsValues.empty() || parallelApplied) {
parallelApplied = true;
return false;
}
parallelApplied = true;
- clauseOps.numThreads = ops.numThreads;
+ clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
return true;
}
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b5c51c5a8dff3..7236baf33a15b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip<
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
- Variadic<AnyInteger>:$num_threads_dims_values,
- Optional<IntLikeType>:$num_threads
+ Variadic<IntLikeType>:$num_threads_dims_values
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
- $num_threads, type($num_threads)
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
) `)`
}];
@@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip<
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
Without dims modifier:
- - Uses `num_threads`
- - If lower bound not specified, it defaults to upper bound value
- - Format: `num_threads(bounds : type)`
- - Example: `num_threads(%ub : i32)`
+ - The number of threads is specified by single value in `num_threads_dims_values`
+ - Format: `num_threads(value : type)`
+ - Example: `num_threads(%n : i32)`
}];
let extraClassDeclaration = [{
@@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip<
::mlir::Value getNumThreadsDimsValue(unsigned index) {
assert(index < getNumThreadsDimsCount() &&
"Num threads dims index out of bounds");
+ if(getNumThreadsDimsValues().empty())
+ return nullptr;
return getNumThreadsDimsValues()[index];
}
}];
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ab7bded7835be..5d75613f9b2b6 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -438,9 +438,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
rewriter.eraseOp(reduce);
Value numThreadsVar;
+ SmallVector<Value> numThreadsValues;
if (numThreads > 0) {
numThreadsVar = LLVM::ConstantOp::create(
rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+ numThreadsValues.push_back(numThreadsVar);
}
// Create the parallel wrapper.
auto ompParallel = omp::ParallelOp::create(
@@ -449,8 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
/* num_threads_num_dims = */ nullptr,
- /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
- /* num_threads = */ numThreadsVar,
+ /* num_threads_dims_values = */ numThreadsValues,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
/* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4aa97a76c7cb4..9018f5699c433 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2286,7 +2286,8 @@ LogicalResult TargetOp::verifyRegions() {
if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
parallelOp->isAncestor(capturedOp) &&
- hostEvalArg == parallelOp.getNumThreads())
+ llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
+ hostEvalArg))
continue;
return emitOpError()
@@ -2540,7 +2541,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
/*num_threads_dims=*/nullptr,
/*num_threads_values=*/ValueRange(),
- /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+ /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
/*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
@@ -2551,14 +2552,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
+/// Builds an omp.parallel op from a ParallelOperands bundle. The private and
+/// reduction symbol lists and the reduction by-ref flags are converted to
+/// attributes (makeArrayAttr / makeDenseBoolArrayAttr) and everything is
+/// forwarded to the positional ParallelOp::build overload.
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(
- builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
- clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
- clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsNumDims,
+ clauses.numThreadsDimsValues, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms),
+ clauses.privateNeedsBarrier, clauses.procBindKind,
+ clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2608,13 +2609,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
LogicalResult
verifyNumThreadsClause(Operation *op,
std::optional<IntegerAttr> numThreadsNumDims,
- OperandRange numThreadsDimsValues, Value numThreads) {
- bool hasDimsModifier =
- numThreadsNumDims.has_value() && numThreadsNumDims.value();
- if (hasDimsModifier && numThreads) {
- return op->emitError("num_threads with dims modifier cannot be used "
- "together with number of threads");
- }
+ OperandRange numThreadsDimsValues) {
if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
return failure();
return success();
@@ -2622,9 +2617,9 @@ verifyNumThreadsClause(Operation *op,
LogicalResult ParallelOp::verify() {
// verify num_threads clause restrictions
- if (failed(verifyNumThreadsClause(
- getOperation(), this->getNumThreadsNumDimsAttr(),
- this->getNumThreadsDimsValues(), this->getNumThreads())))
+ if (failed(verifyNumThreadsClause(getOperation(),
+ this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues())))
return failure();
// verify allocate clause restrictions
@@ -4693,33 +4688,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
+/// Parses the operands of a num_threads clause. The dims-modifier form
+/// (`dims(N): values : type`) is attempted first through
+/// parseDimsModifierWithValues; if that does not succeed, a single
+/// `value : type` pair is parsed and appended to `values` / `types`
+/// (this function does not set `dimsAttr` on that fallback path).
+/// NOTE(review): the fallback assumes parseDimsModifierWithValues fails
+/// without consuming tokens when no dims modifier is present — confirm.
static ParseResult
parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types,
- std::optional<OpAsmParser::UnresolvedOperand> &bounds,
- Type &boundsType) {
+ SmallVectorImpl<Type> &types) {
if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
return success();
}
- OpAsmParser::UnresolvedOperand boundsOperand;
- if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
- parser.parseType(boundsType)) {
+ // Without dims modifier: value : type
+ OpAsmParser::UnresolvedOperand singleValue;
+ Type singleType;
+ if (parser.parseOperand(singleValue) || parser.parseColon() ||
+ parser.parseType(singleType)) {
return failure();
}
- bounds = boundsOperand;
+ values.push_back(singleValue);
+ types.push_back(singleType);
return success();
}
+/// Prints the operands of a num_threads clause; counterpart of
+/// parseNumThreadsClause.
static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
IntegerAttr dimsAttr, OperandRange values,
- TypeRange types, Value bounds,
- Type boundsType) {
- if (!values.empty()) {
- printDimsModifierWithValues(p, dimsAttr, values, types);
- }
- if (bounds) {
- p.printOperand(bounds);
- p << " : " << boundsType;
- }
+ TypeRange types) {
+ // Both spellings go through printDimsModifierWithValues:
+ // `dims(N): values : type` with the dims modifier, `value : type` without.
+ // NOTE(review): relies on printDimsModifierWithValues omitting the dims(N)
+ // prefix when `dimsAttr` is null — confirm.
+ printDimsModifierWithValues(p, dimsAttr, values, types);
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d1587dd4ffd41..8cc55d07af0b7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3270,8 +3270,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
assert(!opInst.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- if (auto numThreadsVar = opInst.getNumThreads())
+ "Lowering of num_threads with dims modifier is not yet implemented.");
+ if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
if (auto bind = opInst.getProcBindKind())
@@ -6059,8 +6059,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- if (parallelOp.getNumThreads() == blockArg)
+ "Lowering of num_threads with dims modifier is not yet "
+ "implemented.");
+ if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
numThreads = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6183,9 +6184,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.hasNumThreadsDimsModifier() &&
- "Lowering of num_threads with dims modifier is NYI.");
- numThreads = parallelOp.getNumThreads();
+ assert(
+ !parallelOp.hasNumThreadsDimsModifier() &&
+ "Lowering of num_threads with dims modifier is not yet implemented.");
+ numThreads = parallelOp.getNumThreadsDimsValue(0);
}
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index db0ddcb415d42..b05bbd4056525 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() {
// expected-error @+1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
@@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
-func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error @+1 {{num_threads with dims modifier cannot be used together with number of threads}}
- "omp.parallel"(%n, %n, %m) ({
+// More than one num_threads value is only accepted together with the dims modifier.
+func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
+ // expected-error @+1 {{dims values can only be specified with dims modifier}}
+ "omp.parallel"(%n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
return
}
@@ -2797,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 91cf899349c01..8ca4481779ead 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
>From d927768cd69d7a1164cc2944be23f4360f2f0352 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 Jan 2026 12:07:56 +0530
Subject: [PATCH 5/5] Fix: also copy numThreadsNumDims in the ParallelOperands
apply method
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index bdbabc292349a..5ca228e218c37 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -177,6 +177,7 @@ class HostEvalInfo {
parallelApplied = true;
clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
+ clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
return true;
}
More information about the llvm-branch-commits
mailing list