[llvm-branch-commits] [flang] [mlir] [OpenMP][MLIR] Add thread_limit with dims modifier support (PR #171825)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Fri Jan 16 20:53:06 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171825
>From e3e8ed5a2bf6d33716efb6741d03891bfe3f6947 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 08:48:30 +0530
Subject: [PATCH 1/9] Update num_teams to have just the list and no dims(N)
syntax
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 2 +-
mlir/test/Dialect/OpenMP/ops.mlir | 4 ++--
2 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 67ff9023a38da..e5b98024dbed1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4588,7 +4588,7 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
p << " : " << upperBoundType;
} else {
// Upper only: to upper : type
- p << " to ";
+ p << "to ";
p.printOperand(upperBound);
p << " : " << upperBoundType;
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 89c7e5fd48bd9..d28f31c8328b2 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1103,7 +1103,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
- // CHECK: omp.teams num_teams( to %{{.+}} : i32)
+ // CHECK: omp.teams num_teams(to %{{.+}} : i32)
omp.teams num_teams(to %ub : i32) {
// CHECK: omp.terminator
omp.terminator
@@ -3084,7 +3084,7 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?
func.func @omp_target_host_eval(%x : i32) {
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
- // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32)
+ // CHECK: omp.teams num_teams(to %[[HOST_ARG]] : i32)
// CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32)
omp.target host_eval(%x -> %arg0 : i32) {
omp.teams num_teams( to %arg0 : i32) thread_limit(%arg0 : i32) {
>From 858f03936cb9423012d8a859cf3b016d67355f1f Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 13:35:05 +0530
Subject: [PATCH 2/9] [OpenMP][MLIR] Add thread_limit with dims modifier
support
---
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 16 +-
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 29 +++-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 69 ++++++++-
mlir/test/Dialect/OpenMP/invalid.mlir | 139 +++++++++++++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 8 +-
5 files changed, 249 insertions(+), 12 deletions(-)
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 7b61539984232..a3b9e5c76bdd2 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -766,6 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
@@ -1485,8 +1486,9 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
IRMapping preMapping;
@@ -1575,8 +1577,9 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
rewriter.createBlock(&isolatedTargetOp.getRegion(),
isolatedTargetOp.getRegion().begin(), {}, {});
@@ -1655,8 +1658,9 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateNeedsBarrierAttr(),
+ targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+ targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
// Create the block for postTargetOp
auto *postTargetBlock = rewriter.createBlock(
&postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d4640f254ed1f..d2ebd29229e84 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1452,16 +1452,43 @@ class OpenMP_ThreadLimitClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
+ Variadic<AnyInteger>:$thread_limit_dims_values,
Optional<AnyInteger>:$thread_limit
);
let optAssemblyFormat = [{
- `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
+ `thread_limit` `(` custom<ThreadLimitClause>(
+ $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values),
+ $thread_limit, type($thread_limit)
+ ) `)`
}];
let description = [{
The optional `thread_limit` specifies the limit on the number of threads.
}];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasThreadLimitDimsModifier() {
+ return getThreadLimitNumDims().has_value() && getThreadLimitNumDims().value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getThreadLimitDimsCount() {
+ if (!hasThreadLimitDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getThreadLimitNumDims());
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getThreadLimitDimsCount()
+ ::mlir::Value getThreadLimitDimensionValue(unsigned index) {
+ assert(index < getThreadLimitDimsCount() &&
+ "Thread limit dims index out of bounds");
+ return getThreadLimitDimsValues()[index];
+ }
+ }];
}
def OpenMP_ThreadLimitClause : OpenMP_ThreadLimitClauseSkip<>;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e5b98024dbed1..e83419492d28e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2210,10 +2210,30 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.threadLimit,
+ clauses.privateNeedsBarrier, clauses.threadLimitNumDims,
+ clauses.threadLimitDimsValues, clauses.threadLimit,
/*private_maps=*/nullptr);
}
+// helper for thread_limit clause restrictions
+static LogicalResult
+verifyThreadLimitClause(Operation *op,
+ std::optional<IntegerAttr> threadLimitNumDims,
+ OperandRange threadLimitDimsValues, Value threadLimit) {
+ bool hasDimsModifier =
+ threadLimitNumDims.has_value() && threadLimitNumDims.value();
+
+ if (hasDimsModifier && threadLimit) {
+ return op->emitError("thread_limit with dims modifier cannot be used "
+ "together with number of threads");
+ }
+
+ if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues)))
+ return failure();
+
+ return success();
+}
+
LogicalResult TargetOp::verify() {
if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
return failure();
@@ -2225,6 +2245,11 @@ LogicalResult TargetOp::verify() {
if (failed(verifyMapClause(*this, getMapVars())))
return failure();
+ if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
+ getThreadLimitDimsValues(),
+ getThreadLimit())))
+ return failure();
+
return verifyPrivateVarsMapping(*this);
}
@@ -2687,6 +2712,12 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
+ // Check for thread_limit clause restrictions
+ if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
+ getThreadLimitDimsValues(),
+ getThreadLimit())))
+ return failure();
+
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4595,6 +4626,42 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for thread_limit clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+ Type &boundsType) {
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ OpAsmParser::UnresolvedOperand boundsOperand;
+ if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+ parser.parseType(boundsType)) {
+ return failure();
+ }
+ bounds = boundsOperand;
+ return success();
+}
+
+static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value bounds,
+ Type boundsType) {
+ if (!values.empty()) {
+ // Multidimensional: dims(N): values : type
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ } else if (bounds) {
+ // Both bounds: bounds : type
+ p.printOperand(bounds);
+ p << " : " << boundsType;
+ }
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index bb882db73cbab..e841e65d36292 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
@@ -1489,6 +1489,139 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
// -----
+func.func @omp_teams_thread_limit_dims_mismatch() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{dims(3) specified but 2 values provided}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_dims_with_scalar() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ %tl = arith.constant 4 : i32
+ // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+ "omp.teams" (%v0, %v1, %tl) ({
+ omp.terminator
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_dims_no_values() {
+ omp.target {
+ // expected-error @below {{dims modifier requires values to be specified}}
+ "omp.teams" () ({
+ omp.terminator
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_values_without_dims() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{dims values can only be specified with dims modifier}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_dims_type_mismatch() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i64
+ // expected-error @below {{dims modifier requires all values to have the same type}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_mismatch() {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{dims(3) specified but 2 values provided}}
+ "omp.target" (%v0, %v1) ({
+ omp.terminator
+ }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_with_scalar() {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ %tl = arith.constant 4 : i32
+ // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+ "omp.target" (%v0, %v1, %tl) ({
+ omp.terminator
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+ return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_no_values() {
+ // expected-error @below {{dims modifier requires values to be specified}}
+ "omp.target" () ({
+ omp.terminator
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
+ return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_values_without_dims() {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{dims values can only be specified with dims modifier}}
+ "omp.target" (%v0, %v1) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_type_mismatch() {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i64
+ // expected-error @below {{dims modifier requires all values to have the same type}}
+ "omp.target" (%v0, %v1) ({
+ omp.terminator
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+ return
+}
+
+// -----
+
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.sections" (%data_var) ({
@@ -2475,7 +2608,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index d28f31c8328b2..39ddc5bfa4e50 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -824,7 +824,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%device, %if_cond, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1136,6 +1136,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
+ // CHECK: omp.teams thread_limit(dims(2): %{{.*}}, %{{.*}} : i32)
+ omp.teams thread_limit(dims(2): %lb, %ub : i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+
// Test reduction.
%c1 = arith.constant 1 : i32
%0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
>From 1f1eac98b3a73d389a805ffe077a10fb08943599 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 18:35:55 +0530
Subject: [PATCH 3/9] update thread_limit description
---
mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 13 ++++++++++++-
1 file changed, 12 insertions(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d2ebd29229e84..c8d8d0003deef 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1465,7 +1465,18 @@ class OpenMP_ThreadLimitClauseSkip<
}];
let description = [{
- The optional `thread_limit` specifies the limit on the number of threads.
+ The `thread_limit` clause specifies the limit on the number of threads.
+
+ With dims modifier:
+ - The number of dimensions is specified by the `thread_limit_num_dims` attribute.
+ - The values for each dimension are specified by the `thread_limit_dims_values` attribute.
+ - Format: `thread_limit(dims(N): values : type)`
+ - Example: `thread_limit(dims(2): %n, %m : i64)`
+
+ Without dims modifier:
+ - The number of threads is specified by the `thread_limit`.
+ - Format: `thread_limit(number_of_threads : type)`
+ - Example: `thread_limit(%n : i64)`
}];
let extraClassDeclaration = [{
>From 12da9b0afc48b841481ec35fd2de72ffbb451459 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 17 Dec 2025 09:06:36 +0530
Subject: [PATCH 4/9] Remove separate thread_limit argument from clause
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 3 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 16 +++---
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 8 +--
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 16 +++---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 54 +++++++------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 11 ++--
mlir/test/Dialect/OpenMP/invalid.mlir | 28 +++++-----
mlir/test/Dialect/OpenMP/ops.mlir | 2 +-
8 files changed, 63 insertions(+), 75 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..18bab01d94365 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -668,8 +668,9 @@ bool ClauseProcessor::processThreadLimit(
lower::StatementContext &stmtCtx,
mlir::omp::ThreadLimitClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
- result.threadLimit =
+ mlir::Value threadLimitVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+ result.threadLimitDimsValues.push_back(threadLimitVal);
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 989e370870f33..1021742b87b2f 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -102,8 +102,9 @@ class HostEvalInfo {
if (ops.numThreads)
vars.push_back(ops.numThreads);
- if (ops.threadLimit)
- vars.push_back(ops.threadLimit);
+ // Old spec (no dims): single value in threadLimitDimsValues.
+ for (mlir::Value val : ops.threadLimitDimsValues)
+ vars.push_back(val);
}
/// Update \c ops, replacing all values with the corresponding block argument
@@ -116,7 +117,7 @@ class HostEvalInfo {
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
(ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
- (ops.threadLimit ? 1 : 0) &&
+ ops.threadLimitDimsValues.size() &&
"invalid block argument list");
int argIndex = 0;
for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
@@ -137,8 +138,8 @@ class HostEvalInfo {
if (ops.numThreads)
ops.numThreads = args[argIndex++];
- if (ops.threadLimit)
- ops.threadLimit = args[argIndex++];
+ for (size_t i = 0; i < ops.threadLimitDimsValues.size(); ++i)
+ ops.threadLimitDimsValues[i] = args[argIndex++];
}
/// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
@@ -185,12 +186,13 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::TeamsOperands &clauseOps) {
- if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+ if (!ops.numTeamsLower && !ops.numTeamsUpper &&
+ ops.threadLimitDimsValues.empty())
return false;
clauseOps.numTeamsLower = ops.numTeamsLower;
clauseOps.numTeamsUpper = ops.numTeamsUpper;
- clauseOps.threadLimit = ops.threadLimit;
+ clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues;
return true;
}
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index a3b9e5c76bdd2..4d3fec3b0710f 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -767,7 +767,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
- targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
rewriter.replaceOp(targetOp, targetDataOp);
@@ -1488,7 +1488,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(),
targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
- targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
IRMapping preMapping;
@@ -1579,7 +1579,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(),
targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
- targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
rewriter.createBlock(&isolatedTargetOp.getRegion(),
isolatedTargetOp.getRegion().begin(), {}, {});
@@ -1660,7 +1660,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
targetOp.getPrivateNeedsBarrierAttr(),
targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
- targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+ targetOp.getPrivateMapsAttr());
// Create the block for postTargetOp
auto *postTargetBlock = rewriter.createBlock(
&postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index c8d8d0003deef..de39a94c17a6e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1453,14 +1453,12 @@ class OpenMP_ThreadLimitClauseSkip<
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
- Variadic<AnyInteger>:$thread_limit_dims_values,
- Optional<AnyInteger>:$thread_limit
+ Variadic<AnyInteger>:$thread_limit_dims_values
);
let optAssemblyFormat = [{
`thread_limit` `(` custom<ThreadLimitClause>(
- $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values),
- $thread_limit, type($thread_limit)
+ $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values)
) `)`
}];
@@ -1468,14 +1466,14 @@ class OpenMP_ThreadLimitClauseSkip<
The `thread_limit` clause specifies the limit on the number of threads.
With dims modifier:
- - The number of dimensions is specified by the `thread_limit_num_dims` attribute.
- - The values for each dimension are specified by the `thread_limit_dims_values` attribute.
+ - The number of dimensions is specified by the `thread_limit_num_dims` attribute.
+ - The values for each dimension are specified by the `thread_limit_dims_values` operands.
- Format: `thread_limit(dims(N): values : type)`
- Example: `thread_limit(dims(2): %n, %m : i64)`
Without dims modifier:
- - The number of threads is specified by the `thread_limit`.
- - Format: `thread_limit(number_of_threads : type)`
+ - The number of threads is specified by the single value in `thread_limit_dims_values`.
+ - Format: `thread_limit(value : type)`
- Example: `thread_limit(%n : i64)`
}];
@@ -1497,6 +1495,8 @@ class OpenMP_ThreadLimitClauseSkip<
::mlir::Value getThreadLimitDimensionValue(unsigned index) {
assert(index < getThreadLimitDimsCount() &&
"Thread limit dims index out of bounds");
+ if (getThreadLimitDimsValues().empty())
+ return nullptr;
return getThreadLimitDimsValues()[index];
}
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index e83419492d28e..4f96b1c079670 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2211,7 +2211,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
clauses.privateNeedsBarrier, clauses.threadLimitNumDims,
- clauses.threadLimitDimsValues, clauses.threadLimit,
+ clauses.threadLimitDimsValues,
/*private_maps=*/nullptr);
}
@@ -2219,15 +2219,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
static LogicalResult
verifyThreadLimitClause(Operation *op,
std::optional<IntegerAttr> threadLimitNumDims,
- OperandRange threadLimitDimsValues, Value threadLimit) {
- bool hasDimsModifier =
- threadLimitNumDims.has_value() && threadLimitNumDims.value();
-
- if (hasDimsModifier && threadLimit) {
- return op->emitError("thread_limit with dims modifier cannot be used "
- "together with number of threads");
- }
-
+ OperandRange threadLimitDimsValues) {
if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues)))
return failure();
@@ -2246,8 +2238,7 @@ LogicalResult TargetOp::verify() {
return failure();
if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
- getThreadLimitDimsValues(),
- getThreadLimit())))
+ getThreadLimitDimsValues())))
return failure();
return verifyPrivateVarsMapping(*this);
@@ -2265,10 +2256,9 @@ LogicalResult TargetOp::verifyRegions() {
cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
for (Operation *user : hostEvalArg.getUsers()) {
if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
- if (llvm::is_contained({teamsOp.getNumTeamsLower(),
- teamsOp.getNumTeamsUpper(),
- teamsOp.getThreadLimit()},
- hostEvalArg))
+ if (teamsOp.getNumTeamsLower() == hostEvalArg ||
+ teamsOp.getNumTeamsUpper() == hostEvalArg ||
+ llvm::is_contained(teamsOp.getThreadLimitDimsValues(), hostEvalArg))
continue;
return emitOpError() << "host_eval argument only legal as 'num_teams' "
@@ -2714,8 +2704,7 @@ LogicalResult TeamsOp::verify() {
// Check for thread_limit clause restrictions
if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
- getThreadLimitDimsValues(),
- getThreadLimit())))
+ getThreadLimitDimsValues())))
return failure();
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
@@ -4632,34 +4621,29 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
static ParseResult
parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types,
- std::optional<OpAsmParser::UnresolvedOperand> &bounds,
- Type &boundsType) {
+ SmallVectorImpl<Type> &types) {
+ // Try parsing with dims modifier: dims(N): values : type
if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
return success();
}
- OpAsmParser::UnresolvedOperand boundsOperand;
- if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
- parser.parseType(boundsType)) {
+ // Without dims modifier: value : type
+ OpAsmParser::UnresolvedOperand singleValue;
+ Type singleType;
+ if (parser.parseOperand(singleValue) || parser.parseColon() ||
+ parser.parseType(singleType)) {
return failure();
}
- bounds = boundsOperand;
+ values.push_back(singleValue);
+ types.push_back(singleType);
return success();
}
static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
IntegerAttr dimsAttr, OperandRange values,
- TypeRange types, Value bounds,
- Type boundsType) {
- if (!values.empty()) {
- // Multidimensional: dims(N): values : type
- printDimsModifierWithValues(p, dimsAttr, values, types);
- } else if (bounds) {
- // Both bounds: bounds : type
- p.printOperand(bounds);
- p << " : " << boundsType;
- }
+ TypeRange types) {
+ // With dims: `dims(N): values : type`; without dims: `value : type`.
+ printDimsModifierWithValues(p, dimsAttr, values, types);
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8a3a990e5a3fd..8cb6d4b21b8b2 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2075,7 +2075,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
llvm::Value *threadLimit = nullptr;
- if (Value threadLimitVar = op.getThreadLimit())
+ if (Value threadLimitVar = op.getThreadLimitDimensionValue(0))
threadLimit = moduleTranslation.lookupValue(threadLimitVar);
llvm::Value *ifExpr = nullptr;
@@ -6044,7 +6044,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
numTeamsLower = hostEvalVar;
else if (teamsOp.getNumTeamsUpper() == blockArg)
numTeamsUpper = hostEvalVar;
- else if (teamsOp.getThreadLimit() == blockArg)
+ else if (teamsOp.getThreadLimitDimensionValue(0) == blockArg)
threadLimit = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6164,7 +6164,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
numTeamsLower = teamsOp.getNumTeamsLower();
numTeamsUpper = teamsOp.getNumTeamsUpper();
- threadLimit = teamsOp.getThreadLimit();
+ threadLimit = teamsOp.getThreadLimitDimensionValue(0);
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
@@ -6209,7 +6209,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Extract 'thread_limit' clause from 'target' and 'teams' directives.
int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
- setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
+ setMaxValueFromClause(targetOp.getThreadLimitDimensionValue(0),
+ targetThreadLimitVal);
setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
@@ -6288,7 +6289,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
- if (Value targetThreadLimit = targetOp.getThreadLimit())
+ if (Value targetThreadLimit = targetOp.getThreadLimitDimensionValue(0))
attrs.TargetThreadLimit.front() =
moduleTranslation.lookupValue(targetThreadLimit);
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index e841e65d36292..649c0fde35ee6 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
@@ -1496,7 +1496,7 @@ func.func @omp_teams_thread_limit_dims_mismatch() {
// expected-error @below {{dims(3) specified but 2 values provided}}
"omp.teams" (%v0, %v1) ({
omp.terminator
- }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
omp.terminator
}
return
@@ -1509,10 +1509,10 @@ func.func @omp_teams_thread_limit_dims_with_scalar() {
%v0 = arith.constant 1 : i32
%v1 = arith.constant 2 : i32
%tl = arith.constant 4 : i32
- // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+ // expected-error @below {{dims(2) specified but 3 values provided}}
"omp.teams" (%v0, %v1, %tl) ({
omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
omp.terminator
}
return
@@ -1540,7 +1540,7 @@ func.func @omp_teams_thread_limit_values_without_dims() {
// expected-error @below {{dims values can only be specified with dims modifier}}
"omp.teams" (%v0, %v1) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
omp.terminator
}
return
@@ -1555,7 +1555,7 @@ func.func @omp_teams_thread_limit_dims_type_mismatch() {
// expected-error @below {{dims modifier requires all values to have the same type}}
"omp.teams" (%v0, %v1) ({
omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
omp.terminator
}
return
@@ -1569,7 +1569,7 @@ func.func @omp_target_thread_limit_dims_mismatch() {
// expected-error @below {{dims(3) specified but 2 values provided}}
"omp.target" (%v0, %v1) ({
omp.terminator
- }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
return
}
@@ -1579,10 +1579,10 @@ func.func @omp_target_thread_limit_dims_with_scalar() {
%v0 = arith.constant 1 : i32
%v1 = arith.constant 2 : i32
%tl = arith.constant 4 : i32
- // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+ // expected-error @below {{dims(2) specified but 3 values provided}}
"omp.target" (%v0, %v1, %tl) ({
omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
return
}
@@ -1592,7 +1592,7 @@ func.func @omp_target_thread_limit_dims_no_values() {
// expected-error @below {{dims modifier requires values to be specified}}
"omp.target" () ({
omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
return
}
@@ -1604,7 +1604,7 @@ func.func @omp_target_thread_limit_values_without_dims() {
// expected-error @below {{dims values can only be specified with dims modifier}}
"omp.target" (%v0, %v1) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
return
}
@@ -1616,7 +1616,7 @@ func.func @omp_target_thread_limit_dims_type_mismatch() {
// expected-error @below {{dims modifier requires all values to have the same type}}
"omp.target" (%v0, %v1) ({
omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+ }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
return
}
@@ -2608,7 +2608,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
// expected-error @below {{op expected as many depend values as depend variables}}
"omp.target"(%data_var) ({
"omp.terminator"() : () -> ()
- }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+ }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
"func.return"() : () -> ()
}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 39ddc5bfa4e50..51b1eed766ac3 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -824,7 +824,7 @@ func.func @omp_target(%if_cond : i1, %device : si32, %num_threads : i32, %devic
"omp.target"(%device, %if_cond, %num_threads) ({
// CHECK: omp.terminator
omp.terminator
- }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+ }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
// Test with optional map clause.
// CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>) map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
>From 243d8dc86b3602cacc18e754fb41a935f502cdd4 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 10:11:01 +0530
Subject: [PATCH 5/9] comments fixes
---
mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index de39a94c17a6e..7d2810018e45f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1453,7 +1453,7 @@ class OpenMP_ThreadLimitClauseSkip<
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
- Variadic<AnyInteger>:$thread_limit_dims_values
+ Variadic<IntLikeType>:$thread_limit_dims_values
);
let optAssemblyFormat = [{
>From 7876b6b7015f09102b57970a6d04a909de953dea Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 15:07:17 +0530
Subject: [PATCH 6/9] fix comment
---
.../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8cb6d4b21b8b2..7b2426e860d8d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6040,6 +6040,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
for (Operation *user : blockArg.getUsers()) {
llvm::TypeSwitch<Operation *>(user)
.Case([&](omp::TeamsOp teamsOp) {
+ // num_teams dims and values are not yet supported
if (teamsOp.getNumTeamsLower() == blockArg)
numTeamsLower = hostEvalVar;
else if (teamsOp.getNumTeamsUpper() == blockArg)
>From f40deb8a3b3752d28ffcc3e6cb983a9ca2ac66e9 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 Jan 2026 12:19:24 +0530
Subject: [PATCH 7/9] [Flang] Add missing threadLimitNumDims in TeamsOperands
apply method
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 1 +
1 file changed, 1 insertion(+)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1021742b87b2f..f05c918ed3cde 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -193,6 +193,7 @@ class HostEvalInfo {
clauseOps.numTeamsLower = ops.numTeamsLower;
clauseOps.numTeamsUpper = ops.numTeamsUpper;
clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues;
+ clauseOps.threadLimitNumDims = ops.threadLimitNumDims;
return true;
}
>From f72662402e2d1f3d96476b305dea242550c7ea78 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 09:55:34 +0530
Subject: [PATCH 8/9] remove dims(N) syntax and just use list for dims vals
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 2 +-
flang/lib/Lower/OpenMP/OpenMP.cpp | 15 +-
.../Optimizer/OpenMP/LowerWorkdistribute.cpp | 12 +-
flang/test/Lower/OpenMP/teams.f90 | 2 +-
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 48 +++----
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 56 ++------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 24 +++-
mlir/test/Dialect/OpenMP/invalid.mlir | 133 ------------------
mlir/test/Dialect/OpenMP/ops.mlir | 15 +-
mlir/test/Target/LLVMIR/openmp-todo.mlir | 11 ++
10 files changed, 87 insertions(+), 231 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 18bab01d94365..8083b1b10aee7 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -670,7 +670,7 @@ bool ClauseProcessor::processThreadLimit(
if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
mlir::Value threadLimitVal =
fir::getBase(converter.genExprValue(clause->v, stmtCtx));
- result.threadLimitDimsValues.push_back(threadLimitVal);
+ result.threadLimitVals.push_back(threadLimitVal);
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index f05c918ed3cde..b8068ee09cf81 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -102,8 +102,7 @@ class HostEvalInfo {
if (ops.numThreads)
vars.push_back(ops.numThreads);
- // Old spec: single value in threadLimitDimsValues
- for (mlir::Value val : ops.threadLimitDimsValues)
+ for (mlir::Value val : ops.threadLimitVals)
vars.push_back(val);
}
@@ -117,7 +116,7 @@ class HostEvalInfo {
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
(ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
- ops.threadLimitDimsValues.size() &&
+ ops.threadLimitVals.size() &&
"invalid block argument list");
int argIndex = 0;
for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
@@ -138,8 +137,8 @@ class HostEvalInfo {
if (ops.numThreads)
ops.numThreads = args[argIndex++];
- for (size_t i = 0; i < ops.threadLimitDimsValues.size(); ++i)
- ops.threadLimitDimsValues[i] = args[argIndex++];
+ for (size_t i = 0; i < ops.threadLimitVals.size(); ++i)
+ ops.threadLimitVals[i] = args[argIndex++];
}
/// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
@@ -186,14 +185,12 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::TeamsOperands &clauseOps) {
- if (!ops.numTeamsLower && !ops.numTeamsUpper &&
- ops.threadLimitDimsValues.empty())
+ if (!ops.numTeamsLower && !ops.numTeamsUpper && ops.threadLimitVals.empty())
return false;
clauseOps.numTeamsLower = ops.numTeamsLower;
clauseOps.numTeamsUpper = ops.numTeamsUpper;
- clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues;
- clauseOps.threadLimitNumDims = ops.threadLimitNumDims;
+ clauseOps.threadLimitVals = ops.threadLimitVals;
return true;
}
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 4d3fec3b0710f..b804a14e32f0c 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -766,8 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
- targetOp.getPrivateMapsAttr());
+ targetOp.getThreadLimitVals(), targetOp.getPrivateMapsAttr());
rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
newTargetOp.getRegion().begin());
rewriter.replaceOp(targetOp, targetDataOp);
@@ -1486,8 +1485,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(),
targetOp.getPrivateMapsAttr());
auto *preTargetBlock = rewriter.createBlock(
&preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
@@ -1577,8 +1575,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(),
targetOp.getPrivateMapsAttr());
auto *isolatedTargetBlock =
rewriter.createBlock(&isolatedTargetOp.getRegion(),
@@ -1658,8 +1655,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
- targetOp.getPrivateNeedsBarrierAttr(),
- targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+ targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(),
targetOp.getPrivateMapsAttr());
// Create the block for postTargetOp
auto *postTargetBlock = rewriter.createBlock(
diff --git a/flang/test/Lower/OpenMP/teams.f90 b/flang/test/Lower/OpenMP/teams.f90
index 47d379d6c2842..e5ba7070cf664 100644
--- a/flang/test/Lower/OpenMP/teams.f90
+++ b/flang/test/Lower/OpenMP/teams.f90
@@ -21,7 +21,7 @@ subroutine teams_numteams(num_teams)
integer, intent(inout) :: num_teams
! CHECK: omp.teams
- ! CHECK-SAME: num_teams( to %{{.*}}: i32)
+ ! CHECK-SAME: num_teams(to %{{.*}}: i32)
!$omp teams num_teams(4)
! CHECK: fir.call
call f1()
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 7d2810018e45f..9d4a01e9edf13 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1452,52 +1452,48 @@ class OpenMP_ThreadLimitClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
- Variadic<IntLikeType>:$thread_limit_dims_values
+ Variadic<IntLikeType>:$thread_limit_vals
);
let optAssemblyFormat = [{
`thread_limit` `(` custom<ThreadLimitClause>(
- $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values)
+ $thread_limit_vals, type($thread_limit_vals)
) `)`
}];
let description = [{
The `thread_limit` clause specifies the limit on the number of threads.
- With dims modifier:
- - The number of dimensions is specified by the `thread_limit_num_dims`.
- - The values for each dimension are specified by the `thread_limit_dims_values`.
- - Format: `thread_limit(dims(N): values : type)`
- - Example: `thread_limit(dims(2): %n, %m : i64)`
+ Multi-dimensional format (dims modifier):
+ - Multiple values can be specified for multi-dimensional thread limits.
+ - The number of dimensions is derived from the number of values.
+ - Values can have different integer types.
+ - Format: `thread_limit(%v1, %v2, ... : type1, type2, ...)`
+ - Example: `thread_limit(%n, %m : i32, i64)`
- Without dims modifier:
- - The number of threads is specified by the single value in `thread_limit_dims_values`.
- - Format: `thread_limit(value : type)`
- - Example: `thread_limit(%n : i64)`
+ Single value format:
+ - A single value specifies the thread limit.
+ - Format: `thread_limit(%value : type)`
+ - Example: `thread_limit(%n : i32)`
}];
let extraClassDeclaration = [{
- /// Returns true if the dims modifier is explicitly present
- bool hasThreadLimitDimsModifier() {
- return getThreadLimitNumDims().has_value() && getThreadLimitNumDims().value();
+ /// Returns true if using multi-dimensional values (more than one value)
+ bool hasThreadLimitMultiDim() {
+ return getThreadLimitVals().size() > 1;
}
- /// Returns the number of dimensions specified by dims modifier
+ /// Returns the number of dimensions specified for thread_limit
unsigned getThreadLimitDimsCount() {
- if (!hasThreadLimitDimsModifier())
- return 1;
- return static_cast<unsigned>(*getThreadLimitNumDims());
+ return getThreadLimitVals().size();
}
/// Returns the value for a specific dimension index
- /// Index must be less than getThreadLimitDimsCount()
- ::mlir::Value getThreadLimitDimensionValue(unsigned index) {
- assert(index < getThreadLimitDimsCount() &&
- "Thread limit dims index out of bounds");
- if (getThreadLimitDimsValues().empty())
- return nullptr;
- return getThreadLimitDimsValues()[index];
+ /// Index must be less than getThreadLimitVals().size()
+ ::mlir::Value getThreadLimitVal(unsigned index) {
+ assert(index < getThreadLimitVals().size() &&
+ "Thread limit index out of bounds");
+ return getThreadLimitVals()[index];
}
}];
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4f96b1c079670..c22830f3b08ec 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2210,22 +2210,10 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
/*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
clauses.mapVars, clauses.nowait, clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
- clauses.privateNeedsBarrier, clauses.threadLimitNumDims,
- clauses.threadLimitDimsValues,
+ clauses.privateNeedsBarrier, clauses.threadLimitVals,
/*private_maps=*/nullptr);
}
-// helper for thread_limit clause restrictions
-static LogicalResult
-verifyThreadLimitClause(Operation *op,
- std::optional<IntegerAttr> threadLimitNumDims,
- OperandRange threadLimitDimsValues) {
- if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues)))
- return failure();
-
- return success();
-}
-
LogicalResult TargetOp::verify() {
if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
return failure();
@@ -2237,10 +2225,6 @@ LogicalResult TargetOp::verify() {
if (failed(verifyMapClause(*this, getMapVars())))
return failure();
- if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
- getThreadLimitDimsValues())))
- return failure();
-
return verifyPrivateVarsMapping(*this);
}
@@ -2258,7 +2242,7 @@ LogicalResult TargetOp::verifyRegions() {
if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
if (teamsOp.getNumTeamsLower() == hostEvalArg ||
teamsOp.getNumTeamsUpper() == hostEvalArg ||
- llvm::is_contained(teamsOp.getThreadLimitDimsValues(), hostEvalArg))
+ llvm::is_contained(teamsOp.getThreadLimitVals(), hostEvalArg))
continue;
return emitOpError() << "host_eval argument only legal as 'num_teams' "
@@ -2647,7 +2631,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
clauses.reductionVars,
makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
makeArrayAttr(ctx, clauses.reductionSyms),
- clauses.threadLimit);
+ clauses.threadLimitVals);
}
// Verify num_teams clause
@@ -2702,11 +2686,6 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- // Check for thread_limit clause restrictions
- if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
- getThreadLimitDimsValues())))
- return failure();
-
return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
getReductionByref());
}
@@ -4608,7 +4587,7 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
p << " : " << upperBoundType;
} else {
// Upper only: to upper : type
- p << "to ";
+ p << " to ";
p.printOperand(upperBound);
p << " : " << upperBoundType;
}
@@ -4619,31 +4598,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
// Parser and printer for thread_limit clause
//===----------------------------------------------------------------------===//
static ParseResult
-parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+parseThreadLimitClause(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
SmallVectorImpl<Type> &types) {
- // Try parsing with dims modifier: dims(N): values : type
- if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
- return success();
- }
-
- // Without dims modifier: value : type
- OpAsmParser::UnresolvedOperand singleValue;
- Type singleType;
- if (parser.parseOperand(singleValue) || parser.parseColon() ||
- parser.parseType(singleType)) {
+ // Parse comma-separated list of values with their types
+ // Format: %v1, %v2, ... : type1, type2, ...
+ if (parser.parseOperandList(values) || parser.parseColon() ||
+ parser.parseTypeList(types)) {
return failure();
}
- values.push_back(singleValue);
- types.push_back(singleType);
return success();
}
static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
- IntegerAttr dimsAttr, OperandRange values,
- TypeRange types) {
- // Multidimensional: dims(N): values : type
- printDimsModifierWithValues(p, dimsAttr, values, types);
+ OperandRange values, TypeRange types) {
+ // Print values with their types
+ llvm::interleaveComma(values, p, [&](Value v) { p << v; });
+ p << " : ";
+ llvm::interleaveComma(types, p, [&](Type t) { p << t; });
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7b2426e860d8d..7abbeaedc446d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
+ auto checkThreadLimitMultiDim = [&todo](auto op, LogicalResult &result) {
+ if (op.hasThreadLimitMultiDim())
+ result = todo("thread_limit with multi-dimensional values");
+ };
LogicalResult result = success();
llvm::TypeSwitch<Operation &>(op)
@@ -405,6 +409,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkPrivate(op, result);
checkNumTeamsMultiDim(op, result);
+ checkThreadLimitMultiDim(op, result);
})
.Case([&](omp::TaskOp op) {
checkAllocate(op, result);
@@ -442,6 +447,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkBare(op, result);
checkInReduction(op, result);
+ checkThreadLimitMultiDim(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
@@ -2075,8 +2081,8 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
llvm::Value *threadLimit = nullptr;
- if (Value threadLimitVar = op.getThreadLimitDimensionValue(0))
- threadLimit = moduleTranslation.lookupValue(threadLimitVar);
+ if (!op.getThreadLimitVals().empty())
+ threadLimit = moduleTranslation.lookupValue(op.getThreadLimitVal(0));
llvm::Value *ifExpr = nullptr;
if (Value ifVar = op.getIfExpr())
@@ -6045,7 +6051,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
numTeamsLower = hostEvalVar;
else if (teamsOp.getNumTeamsUpper() == blockArg)
numTeamsUpper = hostEvalVar;
- else if (teamsOp.getThreadLimitDimensionValue(0) == blockArg)
+ else if (!teamsOp.getThreadLimitVals().empty() &&
+ teamsOp.getThreadLimitVal(0) == blockArg)
threadLimit = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6165,7 +6172,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
numTeamsLower = teamsOp.getNumTeamsLower();
numTeamsUpper = teamsOp.getNumTeamsUpper();
- threadLimit = teamsOp.getThreadLimitDimensionValue(0);
+ if (!teamsOp.getThreadLimitVals().empty())
+ threadLimit = teamsOp.getThreadLimitVal(0);
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
@@ -6210,8 +6218,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Extract 'thread_limit' clause from 'target' and 'teams' directives.
int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
- setMaxValueFromClause(targetOp.getThreadLimitDimensionValue(0),
- targetThreadLimitVal);
+ if (!targetOp.getThreadLimitVals().empty())
+ setMaxValueFromClause(targetOp.getThreadLimitVal(0), targetThreadLimitVal);
setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
@@ -6290,9 +6298,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
// TODO: Handle constant 'if' clauses.
- if (Value targetThreadLimit = targetOp.getThreadLimitDimensionValue(0))
+ if (!targetOp.getThreadLimitVals().empty()) {
+ Value targetThreadLimit = targetOp.getThreadLimitVal(0);
attrs.TargetThreadLimit.front() =
moduleTranslation.lookupValue(targetThreadLimit);
+ }
if (numTeamsLower)
attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 649c0fde35ee6..bb882db73cbab 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1489,139 +1489,6 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
// -----
-func.func @omp_teams_thread_limit_dims_mismatch() {
- omp.target {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- // expected-error @below {{dims(3) specified but 2 values provided}}
- "omp.teams" (%v0, %v1) ({
- omp.terminator
- }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
- omp.terminator
- }
- return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_dims_with_scalar() {
- omp.target {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- %tl = arith.constant 4 : i32
- // expected-error @below {{dims(2) specified but 3 values provided}}
- "omp.teams" (%v0, %v1, %tl) ({
- omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
- omp.terminator
- }
- return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_dims_no_values() {
- omp.target {
- // expected-error @below {{dims modifier requires values to be specified}}
- "omp.teams" () ({
- omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
- omp.terminator
- }
- return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_values_without_dims() {
- omp.target {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- // expected-error @below {{dims values can only be specified with dims modifier}}
- "omp.teams" (%v0, %v1) ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
- omp.terminator
- }
- return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_dims_type_mismatch() {
- omp.target {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i64
- // expected-error @below {{dims modifier requires all values to have the same type}}
- "omp.teams" (%v0, %v1) ({
- omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
- omp.terminator
- }
- return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_mismatch() {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- // expected-error @below {{dims(3) specified but 2 values provided}}
- "omp.target" (%v0, %v1) ({
- omp.terminator
- }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
- return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_with_scalar() {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- %tl = arith.constant 4 : i32
- // expected-error @below {{dims(2) specified but 3 values provided}}
- "omp.target" (%v0, %v1, %tl) ({
- omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
- return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_no_values() {
- // expected-error @below {{dims modifier requires values to be specified}}
- "omp.target" () ({
- omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
- return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_values_without_dims() {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i32
- // expected-error @below {{dims values can only be specified with dims modifier}}
- "omp.target" (%v0, %v1) ({
- omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
- return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_type_mismatch() {
- %v0 = arith.constant 1 : i32
- %v1 = arith.constant 2 : i64
- // expected-error @below {{dims modifier requires all values to have the same type}}
- "omp.target" (%v0, %v1) ({
- omp.terminator
- }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
- return
-}
-
-// -----
-
func.func @omp_sections(%data_var : memref<i32>) -> () {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.sections" (%data_var) ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 51b1eed766ac3..4e5acca796584 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1103,7 +1103,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
- // CHECK: omp.teams num_teams(to %{{.+}} : i32)
+ // CHECK: omp.teams num_teams( to %{{.+}} : i32)
omp.teams num_teams(to %ub : i32) {
// CHECK: omp.terminator
omp.terminator
@@ -1136,8 +1136,15 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
- // CHECK: omp.teams thread_limit(dims(2): %{{.*}}, %{{.*}} : i32)
- omp.teams thread_limit(dims(2): %lb, %ub : i32) {
+ // CHECK: omp.teams thread_limit(%{{.*}}, %{{.*}} : i32, i32)
+ omp.teams thread_limit(%lb, %ub : i32, i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+
+ // Test thread_limit with mixed types.
+ // CHECK: omp.teams thread_limit(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+ omp.teams thread_limit(%lb, %ub64, %ub16 : i32, i64, i16) {
// CHECK: omp.terminator
omp.terminator
}
@@ -3090,7 +3097,7 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?
func.func @omp_target_host_eval(%x : i32) {
// CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
- // CHECK: omp.teams num_teams(to %[[HOST_ARG]] : i32)
+ // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32)
// CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32)
omp.target host_eval(%x -> %arg0 : i32) {
omp.teams num_teams( to %arg0 : i32) thread_limit(%arg0 : i32) {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 3681ce38bd523..c766cc9568b4f 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -452,6 +452,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
// -----
+llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
+ // expected-error@below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}}
+ // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
+ omp.teams thread_limit(%lb, %ub : i32, i32) {
+ omp.terminator
+ }
+ llvm.return
+}
+
+// -----
+
llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
// expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
// expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}
>From bea50e20b09934bad867e6b847517785dedff923 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sat, 17 Jan 2026 10:21:49 +0530
Subject: [PATCH 9/9] remove custom parser/printer for dims
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 14 +++++------
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 24 -------------------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 16 ++++++-------
3 files changed, 14 insertions(+), 40 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 9d4a01e9edf13..1970e2115003f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1456,9 +1456,7 @@ class OpenMP_ThreadLimitClauseSkip<
);
let optAssemblyFormat = [{
- `thread_limit` `(` custom<ThreadLimitClause>(
- $thread_limit_vals, type($thread_limit_vals)
- ) `)`
+ `thread_limit` `(` $thread_limit_vals `:` type($thread_limit_vals) `)`
}];
let description = [{
@@ -1488,12 +1486,12 @@ class OpenMP_ThreadLimitClauseSkip<
return getThreadLimitVals().size();
}
- /// Returns the value for a specific dimension index
- /// Index must be less than getThreadLimitVals().size()
- ::mlir::Value getThreadLimitVal(unsigned index) {
- assert(index < getThreadLimitVals().size() &&
+ /// Returns the value for a specific dimension
+ /// dim must be less than getThreadLimitDimsCount()
+ ::mlir::Value getThreadLimit(unsigned dim = 0) {
+ assert(dim < getThreadLimitDimsCount() &&
"Thread limit index out of bounds");
- return getThreadLimitVals()[index];
+ return getThreadLimitVals()[dim];
}
}];
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c22830f3b08ec..d5532f959ae1b 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4594,30 +4594,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
}
}
-//===----------------------------------------------------------------------===//
-// Parser and printer for thread_limit clause
-//===----------------------------------------------------------------------===//
-static ParseResult
-parseThreadLimitClause(OpAsmParser &parser,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types) {
- // Parse comma-separated list of values with their types
- // Format: %v1, %v2, ... : type1, type2, ...
- if (parser.parseOperandList(values) || parser.parseColon() ||
- parser.parseTypeList(types)) {
- return failure();
- }
- return success();
-}
-
-static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
- OperandRange values, TypeRange types) {
- // Print values with their types
- llvm::interleaveComma(values, p, [&](Value v) { p << v; });
- p << " : ";
- llvm::interleaveComma(types, p, [&](Type t) { p << t; });
-}
-
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 7abbeaedc446d..a25654926ca01 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
if (op.hasNumTeamsMultiDim())
result = todo("num_teams with multi-dimensional values");
};
- auto checkThreadLimitMultiDim = [&todo](auto op, LogicalResult &result) {
+ auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
if (op.hasThreadLimitMultiDim())
result = todo("thread_limit with multi-dimensional values");
};
@@ -409,7 +409,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkPrivate(op, result);
checkNumTeamsMultiDim(op, result);
- checkThreadLimitMultiDim(op, result);
+ checkThreadLimit(op, result);
})
.Case([&](omp::TaskOp op) {
checkAllocate(op, result);
@@ -447,7 +447,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
checkAllocate(op, result);
checkBare(op, result);
checkInReduction(op, result);
- checkThreadLimitMultiDim(op, result);
+ checkThreadLimit(op, result);
})
.Default([](Operation &) {
// Assume all clauses for an operation can be translated unless they are
@@ -2082,7 +2082,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
llvm::Value *threadLimit = nullptr;
if (!op.getThreadLimitVals().empty())
- threadLimit = moduleTranslation.lookupValue(op.getThreadLimitVal(0));
+ threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
llvm::Value *ifExpr = nullptr;
if (Value ifVar = op.getIfExpr())
@@ -6052,7 +6052,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
else if (teamsOp.getNumTeamsUpper() == blockArg)
numTeamsUpper = hostEvalVar;
else if (!teamsOp.getThreadLimitVals().empty() &&
- teamsOp.getThreadLimitVal(0) == blockArg)
+ teamsOp.getThreadLimit(0) == blockArg)
threadLimit = hostEvalVar;
else
llvm_unreachable("unsupported host_eval use");
@@ -6173,7 +6173,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
numTeamsLower = teamsOp.getNumTeamsLower();
numTeamsUpper = teamsOp.getNumTeamsUpper();
if (!teamsOp.getThreadLimitVals().empty())
- threadLimit = teamsOp.getThreadLimitVal(0);
+ threadLimit = teamsOp.getThreadLimit(0);
}
if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
@@ -6219,7 +6219,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// Extract 'thread_limit' clause from 'target' and 'teams' directives.
int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
if (!targetOp.getThreadLimitVals().empty())
- setMaxValueFromClause(targetOp.getThreadLimitVal(0), targetThreadLimitVal);
+ setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
// Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
@@ -6299,7 +6299,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
// TODO: Handle constant 'if' clauses.
if (!targetOp.getThreadLimitVals().empty()) {
- Value targetThreadLimit = targetOp.getThreadLimitVal(0);
+ Value targetThreadLimit = targetOp.getThreadLimit(0);
attrs.TargetThreadLimit.front() =
moduleTranslation.lookupValue(targetThreadLimit);
}
More information about the llvm-branch-commits
mailing list