[llvm-branch-commits] [mlir] [OpenMP][MLIR] Add num_threads clause with dims modifier support (PR #171767)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Dec 11 04:57:36 PST 2025
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767
>From 1c69d29651bb1b73c04cca422454eb7ffffd7c4c Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 11:56:58 +0530
Subject: [PATCH 1/3] [OpenMP][MLIR] Add num_threads clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 50 +++++++++++-
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 2 +
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 79 +++++++++++++++++--
mlir/test/Dialect/OpenMP/invalid.mlir | 33 +++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 15 ++--
5 files changed, 163 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index e36dc7c246f01..7525b6e4e99f6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
+ Variadic<AnyInteger>:$num_threads_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
- `num_threads` `(` $num_threads `:` type($num_threads) `)`
+ `num_threads` `(` custom<NumThreadsClause>(
+ $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads, type($num_threads)
+ ) `)`
}];
let description = [{
- The optional `num_threads` parameter specifies the number of threads which
- should be used to execute the parallel region.
+ The `num_threads` clause specifies the desired number of threads in the team
+ space formed by the construct on which it appears.
+
+ With dims modifier:
+ - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_threads`
+ - If the lower bound is not specified, it defaults to the upper bound value
+ - Format: `num_threads(bounds : type)`
+ - Example: `num_threads(%ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumThreadsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumThreadsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumThreadsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..0d5333ec2e455 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
+ /* num_threads_dims = */ nullptr,
+ /* num_threads_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d4dbf5f5244df..303ab94fbedff 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2533,6 +2533,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ArrayRef<NamedAttribute> attributes) {
ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
/*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+ /*num_threads_dims=*/nullptr,
+ /*num_threads_values=*/ValueRange(),
/*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
/*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
/*proc_bind_kind=*/nullptr,
@@ -2544,13 +2546,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
void ParallelOp::build(OpBuilder &builder, OperationState &state,
const ParallelOperands &clauses) {
MLIRContext *ctx = builder.getContext();
- ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreads, clauses.privateVars,
- makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.procBindKind,
- clauses.reductionMod, clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms));
+ ParallelOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.numThreads, clauses.privateVars,
+ makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+ clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms));
}
template <typename OpType>
@@ -2597,13 +2600,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
}
LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ auto numThreadsDims = getNumThreadsDims();
+ auto numThreadsValues = getNumThreadsValues();
+ auto numThreads = getNumThreads();
+
+ // num_threads with dims modifier
+ if (numThreadsDims.has_value() && numThreadsValues.empty()) {
+ return emitError(
+ "num_threads dims modifier requires values to be specified");
+ }
+
+ if (numThreadsDims.has_value() &&
+ numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
+ return emitError("num_threads dims(")
+ << *numThreadsDims << ") specified but " << numThreadsValues.size()
+ << " values provided";
+ }
+
+ // num_threads dims and number of threads cannot be used together
+ if (numThreadsDims.has_value() && numThreads) {
+ return emitError(
+ "num_threads dims and number of threads cannot be used together");
+ }
+
+ // verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // verify private variables restrictions
if (failed(verifyPrivateVarList(*this)))
return failure();
+ // verify reduction variables restrictions
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4647,6 +4677,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ }
+ if (bounds) {
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index dd367aba8da27..9e2e5722aab9f 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
// -----
+func.func @num_threads_dims_no_values() {
+ // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+ "omp.parallel"() ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+ // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+ omp.parallel num_threads(dims(2): %n : i64) {
+ omp.terminator
+ }
+
+ return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+ // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+ "omp.parallel"(%n, %n, %m) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ return
+}
+
+// -----
+
func.func @nowait_not_allowed(%n : memref<i32>) {
// expected-error@+1 {{expected '{' to begin a region}}
omp.parallel nowait {}
@@ -2766,7 +2797,7 @@ func.func @undefined_privatizer(%arg0: index) {
// -----
func.func @undefined_privatizer(%arg0: !llvm.ptr) {
// expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
- "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+ "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
^bb0(%arg2: !llvm.ptr):
omp.terminator
}) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3633a4be1eb62..585c9483c08a9 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
"omp.parallel"(%data_var, %data_var, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
// CHECK: omp.barrier
omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
"omp.parallel"(%data_var, %data_var, %if_cond) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
// test without allocate
// CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
"omp.parallel"(%if_cond, %num_threads) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
// test with multiple parameters for single variadic argument
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
"omp.parallel" (%data_var, %data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
// CHECK: omp.parallel
omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
omp.terminator
}
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ omp.terminator
+ }
+
// CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
omp.terminator
>From 6946aff41bb7f744d6445d0fc227fb7807ea2191 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 12:11:49 +0530
Subject: [PATCH 2/3] Mark mlir->llvmir translation for num_threads with dims
as NYI
---
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 15 ++++++++++++++-
1 file changed, 14 insertions(+), 1 deletion(-)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 00f782e87d5af..8d3d0ccb665bd 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2879,6 +2879,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
if (auto ifVar = opInst.getIfExpr())
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
+ // num_threads dims and values are not yet supported
+ assert(!opInst.getNumThreadsDims().has_value() &&
+ opInst.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -5604,6 +5608,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm_unreachable("unsupported host_eval use");
})
.Case([&](omp::ParallelOp parallelOp) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
else
@@ -5724,8 +5732,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
threadLimit = teamsOp.getThreadLimit();
}
- if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+ if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+ // num_threads dims and values are not yet supported
+ assert(!parallelOp.getNumThreadsDims().has_value() &&
+ parallelOp.getNumThreadsValues().empty() &&
+ "Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
+ }
}
// Handle clauses impacting the number of teams.
>From 33dcfd92bea8181da414b766101847338ee3b963 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 17:37:52 +0530
Subject: [PATCH 3/3] few more fixes
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 33 ++++++--------
.../Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 4 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 44 +++++++++----------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 9 ++--
mlir/test/Dialect/OpenMP/invalid.mlir | 10 ++---
5 files changed, 45 insertions(+), 55 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 7525b6e4e99f6..09c1d4a8a5866 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
- Variadic<AnyInteger>:$num_threads_values,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+ Variadic<AnyInteger>:$num_threads_dims_values,
Optional<IntLikeType>:$num_threads
);
let optAssemblyFormat = [{
`num_threads` `(` custom<NumThreadsClause>(
- $num_threads_dims, $num_threads_values, type($num_threads_values),
+ $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
$num_threads, type($num_threads)
) `)`
}];
@@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip<
space formed by the construct on which it appears.
With dims modifier:
- - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+ - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
- Specifies upper bounds for each dimension (all must have same type)
- Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip<
let extraClassDeclaration = [{
/// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumThreadsDims().has_value();
+ bool hasNumThreadsDimsModifier() {
+ return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
}
/// Returns the number of dimensions specified by dims modifier
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
+ unsigned getNumThreadsDimsCount() {
+ if (!hasNumThreadsDimsModifier())
return 1;
- return static_cast<unsigned>(*getNumThreadsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumThreadsValues();
+ return static_cast<unsigned>(*getNumThreadsNumDims());
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
+ /// Index must be less than getNumThreadsDimsCount()
+ ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+ assert(index < getNumThreadsDimsCount() &&
+ "Num threads dims index out of bounds");
+ return getNumThreadsDimsValues()[index];
}
}];
}
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 0d5333ec2e455..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,8 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/* allocate_vars = */ llvm::SmallVector<Value>{},
/* allocator_vars = */ llvm::SmallVector<Value>{},
/* if_expr = */ Value{},
- /* num_threads_dims = */ nullptr,
- /* num_threads_values = */ llvm::SmallVector<Value>{},
+ /* num_threads_num_dims = */ nullptr,
+ /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
/* num_threads = */ numThreadsVar,
/* private_vars = */ ValueRange(),
/* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 303ab94fbedff..a9ed0274cd21c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2548,7 +2548,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
ParallelOp::build(
builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+ clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
clauses.numThreads, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
@@ -2599,30 +2599,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
return success();
}
-LogicalResult ParallelOp::verify() {
- // verify num_threads clause restrictions
- auto numThreadsDims = getNumThreadsDims();
- auto numThreadsValues = getNumThreadsValues();
- auto numThreads = getNumThreads();
-
- // num_threads with dims modifier
- if (numThreadsDims.has_value() && numThreadsValues.empty()) {
- return emitError(
- "num_threads dims modifier requires values to be specified");
- }
-
- if (numThreadsDims.has_value() &&
- numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
- return emitError("num_threads dims(")
- << *numThreadsDims << ") specified but " << numThreadsValues.size()
- << " values provided";
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+ std::optional<IntegerAttr> numThreadsNumDims,
+ OperandRange numThreadsDimsValues, Value numThreads) {
+ bool hasDimsModifier =
+ numThreadsNumDims.has_value() && numThreadsNumDims.value();
+ if (hasDimsModifier && numThreads) {
+ return op->emitError("num_threads with dims modifier cannot be used "
+ "together with number of threads");
}
+ if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+ return failure();
+ return success();
+}
- // num_threads dims and number of threads cannot be used together
- if (numThreadsDims.has_value() && numThreads) {
- return emitError(
- "num_threads dims and number of threads cannot be used together");
- }
+LogicalResult ParallelOp::verify() {
+ // verify num_threads clause restrictions
+ if (failed(verifyNumThreadsClause(
+ getOperation(), this->getNumThreadsNumDimsAttr(),
+ this->getNumThreadsDimsValues(), this->getNumThreads())))
+ return failure();
// verify allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8d3d0ccb665bd..2bfb9fb2211c4 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2880,8 +2880,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
ifCond = moduleTranslation.lookupValue(ifVar);
llvm::Value *numThreads = nullptr;
// num_threads dims and values are not yet supported
- assert(!opInst.getNumThreadsDims().has_value() &&
- opInst.getNumThreadsValues().empty() &&
+ assert(!opInst.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (auto numThreadsVar = opInst.getNumThreads())
numThreads = moduleTranslation.lookupValue(numThreadsVar);
@@ -5609,8 +5608,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
})
.Case([&](omp::ParallelOp parallelOp) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
if (parallelOp.getNumThreads() == blockArg)
numThreads = hostEvalVar;
@@ -5734,8 +5732,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
// num_threads dims and values are not yet supported
- assert(!parallelOp.getNumThreadsDims().has_value() &&
- parallelOp.getNumThreadsValues().empty() &&
+ assert(!parallelOp.hasNumThreadsDimsModifier() &&
"Lowering of num_threads with dims modifier is NYI.");
numThreads = parallelOp.getNumThreads();
}
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 9e2e5722aab9f..db0ddcb415d42 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) {
// -----
func.func @num_threads_dims_no_values() {
- // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+ // expected-error@+1 {{dims modifier requires values to be specified}}
"omp.parallel"() ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
return
}
// -----
func.func @num_threads_dims_mismatch(%n : i64) {
- // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+ // expected-error@+1 {{dims(2) specified but 1 values provided}}
omp.parallel num_threads(dims(2): %n : i64) {
omp.terminator
}
@@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) {
// -----
func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
- // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+ // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
"omp.parallel"(%n, %n, %m) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
return
}
More information about the llvm-branch-commits
mailing list