[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams clause with dims modifier support (PR #169883)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Jan 13 08:26:24 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/169883
>From 842eb9fcb715dffd67cab8d183b7f6d20232f03e Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 13:37:14 +0530
Subject: [PATCH 1/9] [OpenMP][MLIR] Add num_teams clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 72 +++++++++++++++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 64 +++++++++++++++++
mlir/test/Dialect/OpenMP/invalid.mlir | 4 +-
mlir/test/Dialect/OpenMP/ops.mlir | 6 ++
5 files changed, 146 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 05e2ee4e5632b..1341bba3ad85e 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1532,4 +1532,76 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V6.2: Multidimensional `num_teams` clause with dims modifier
+//===----------------------------------------------------------------------===//
+
+class OpenMP_NumTeamsMultiDimClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+ Variadic<AnyInteger>:$num_teams_values
+ );
+
+ let optAssemblyFormat = [{
+ `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+ $num_teams_values,
+ type($num_teams_values)) `)`
+ }];
+
+ let description = [{
+ The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
+ the number of teams to be created in a multidimensional team space.
+
+ The dims modifier for the num_teams_multi_dim clause specifies the number of
+ dimensions for the league space (team space) that the clause arranges.
+ The dimensions argument in the dims modifier specifies the number of
+ dimensions and determines the length of the list argument. The list items
+ are specified in ascending order according to the ordinal number of the
+ dimensions (dimension 0, 1, 2, ..., N-1).
+
+ - If `dims` is not specified: The space is unidimensional (1D) with a single value
+ - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
+ - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
+
+ **Examples:**
+ - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
+ 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
+ - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumTeamsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ /// Returns 1 if dims modifier is not present (unidimensional by default)
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumTeamsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumTeamsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
+ }];
+}
+
+def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
+
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 1fcd7b3c23e10..5e399d12b98ad 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+ OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+ OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1af525ba92bd0..27ea6da14b182 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2626,6 +2626,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+ clauses.numTeamsDims, clauses.numTeamsValues,
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
@@ -4465,6 +4466,69 @@ LogicalResult WorkdistributeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
+//===----------------------------------------------------------------------===//
+// num_teams_multidim ::= `num_teams` `(` [`dims` `(` dim-count `)` `:`] values
+// `)` Example: num_teams(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
+// num_teams(%v : i32)
+static ParseResult parseNumTeamsMultiDimClause(
+ OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types) {
+ std::optional<int64_t> dims;
+ // Try to parse optional dims modifier: dims(N):
+ if (succeeded(parser.parseOptionalKeyword("dims"))) {
+ int64_t dimsValue;
+ if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
+ parser.parseRParen() || parser.parseColon()) {
+ return failure();
+ }
+ dims = dimsValue;
+ }
+ // Parse the operand list
+ if (parser.parseOperandList(values))
+ return failure();
+ // Parse colon and types
+ if (parser.parseColon() || parser.parseTypeList(types))
+ return failure();
+
+ // Verify dims matches number of values if specified
+ if (dims.has_value() && values.size() != static_cast<size_t>(*dims)) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "dims(" << *dims << ") specified but " << values.size()
+ << " values provided";
+ }
+
+ // If dims not specified but we have values, it's implicitly unidimensional
+ if (!dims.has_value() && values.size() != 1) {
+ return parser.emitError(parser.getCurrentLocation())
+ << "expected 1 value without dims modifier, got " << values.size();
+ }
+
+ // Convert to IntegerAttr
+ if (dims.has_value()) {
+ dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims);
+ }
+ return success();
+}
+
+static void printNumTeamsMultiDimClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr,
+ OperandRange values, TypeRange types) {
+ // Print dims modifier if present
+ if (dimsAttr) {
+ p << "dims(" << dimsAttr.getInt() << "): ";
+ }
+
+ // Print operands
+ p.printOperands(values);
+
+ // Print types
+ p << " : ";
+ llvm::interleaveComma(types, p);
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..62619f07d6573 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 816df56ecc5a5..c1bd3cd0fa446 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1109,6 +1109,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
+ // CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
+ omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+
// Test if.
// CHECK: omp.teams if(%{{.+}})
omp.teams if(%if_cond) {
>From 2afde3990fcef5cc938fb3f2ede3ed564a95af99 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 14:20:39 +0530
Subject: [PATCH 2/9] fix comment
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 27ea6da14b182..a66dd37009e7d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4469,9 +4469,8 @@ LogicalResult WorkdistributeOp::verify() {
//===----------------------------------------------------------------------===//
// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
//===----------------------------------------------------------------------===//
-// num_teams_multidim ::= `num_teams` `(` [`dims` `(` dim-count `)` `:`] values
-// `)` Example: num_teams(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
-// num_teams(%v : i32)
+// num_teams_multi_dim(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
+// num_teams_multi_dim(%v : i32)
static ParseResult parseNumTeamsMultiDimClause(
OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
>From 11124c62ca3033b646ae44639ce11172c5aa0c1d Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 1 Dec 2025 12:56:28 +0530
Subject: [PATCH 3/9] Comments fix
---
mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 2 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 3 ++-
2 files changed, 3 insertions(+), 2 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 1341bba3ad85e..8ea3dedf1d3ac 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1533,7 +1533,7 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
//===----------------------------------------------------------------------===//
-// V6.2: Multidimensional `num_teams` clause with dims modifier
+// V6.1: Multidimensional `num_teams` clause with dims modifier
//===----------------------------------------------------------------------===//
class OpenMP_NumTeamsMultiDimClauseSkip<
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a66dd37009e7d..1dee3121971d5 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4502,7 +4502,8 @@ static ParseResult parseNumTeamsMultiDimClause(
// If dims not specified but we have values, it's implicitly unidimensional
if (!dims.has_value() && values.size() != 1) {
return parser.emitError(parser.getCurrentLocation())
- << "expected 1 value without dims modifier, got " << values.size();
+ << "expected 1 value without dims modifier, but got "
+ << values.size() << " values";
}
// Convert to IntegerAttr
>From 5e286bc452d42ea8d874f10b192fa4220c9a4613 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 1 Dec 2025 15:02:21 +0530
Subject: [PATCH 4/9] Use DimsModifier for custom assembly parser and printer
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 4 ++--
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 20 +++++++++----------
2 files changed, 12 insertions(+), 12 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8ea3dedf1d3ac..468c5ce132aaf 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1533,7 +1533,7 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
//===----------------------------------------------------------------------===//
-// V6.1: Multidimensional `num_teams` clause with dims modifier
+// V6.1: `num_teams` clause with dims modifier
//===----------------------------------------------------------------------===//
class OpenMP_NumTeamsMultiDimClauseSkip<
@@ -1547,7 +1547,7 @@ class OpenMP_NumTeamsMultiDimClauseSkip<
);
let optAssemblyFormat = [{
- `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+ `num_teams_multi_dim` `(` custom<DimsModifier>($num_teams_dims,
$num_teams_values,
type($num_teams_values)) `)`
}];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 1dee3121971d5..be913fa9272c1 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4467,14 +4467,14 @@ LogicalResult WorkdistributeOp::verify() {
}
//===----------------------------------------------------------------------===//
-// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
+// Parser and printer for Clauses with dims modifier
//===----------------------------------------------------------------------===//
-// num_teams_multi_dim(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
-// num_teams_multi_dim(%v : i32)
-static ParseResult parseNumTeamsMultiDimClause(
- OpAsmParser &parser, IntegerAttr &dimsAttr,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types) {
+// clause_name(dims(3): %v0, %v1, %v2 : i32, i32, i32)
+// clause_name(%v : i32)
+static ParseResult
+parseDimsModifier(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types) {
std::optional<int64_t> dims;
// Try to parse optional dims modifier: dims(N):
if (succeeded(parser.parseOptionalKeyword("dims"))) {
@@ -4513,9 +4513,9 @@ static ParseResult parseNumTeamsMultiDimClause(
return success();
}
-static void printNumTeamsMultiDimClause(OpAsmPrinter &p, Operation *op,
- IntegerAttr dimsAttr,
- OperandRange values, TypeRange types) {
+static void printDimsModifier(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types) {
// Print dims modifier if present
if (dimsAttr) {
p << "dims(" << dimsAttr.getInt() << "): ";
>From 8144eeedc17a01877c3e5e3204d9c9232c798d95 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 10 Dec 2025 13:13:17 +0530
Subject: [PATCH 5/9] use dims modifier in main num_teams clause itself instead
of creating new clause
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 128 +++++------
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 204 ++++++++++++++----
mlir/test/Dialect/OpenMP/invalid.mlir | 77 ++++++-
mlir/test/Dialect/OpenMP/ops.mlir | 4 +-
5 files changed, 284 insertions(+), 132 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 468c5ce132aaf..0b17a62c88d92 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -974,22 +974,62 @@ class OpenMP_NumTeamsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+ Variadic<AnyInteger>:$num_teams_values,
Optional<AnyInteger>:$num_teams_lower,
Optional<AnyInteger>:$num_teams_upper
);
let optAssemblyFormat = [{
- `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
- $num_teams_upper `:` type($num_teams_upper) `)`
+ `num_teams` `(` custom<NumTeamsClause>(
+ $num_teams_dims, $num_teams_values, type($num_teams_values),
+ $num_teams_lower, type($num_teams_lower),
+ $num_teams_upper, type($num_teams_upper)
+ ) `)`
}];
let description = [{
- The optional `num_teams_upper` and `num_teams_lower` arguments specify the
- limit on the number of teams to be created. If only the upper bound is
- specified, it acts as if the lower bound was set to the same value. It is
- not allowed to set `num_teams_lower` if `num_teams_upper` is not specified.
- They define a closed range, where both the lower and upper bounds are
- included.
+ The `num_teams` clause specifies the bounds on the league space formed by the
+ construct on which it appears.
+
+ With dims modifier: (OpenMP 6.1 requirement)
+ - Uses `num_teams_dims` (dimension count) and `num_teams_values` (upper bounds list)
+ - Specifies upper bounds for each dimension (all must have same type)
+ - Format: `num_teams(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+ - Example: `num_teams(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+ Without dims modifier:
+ - Uses `num_teams_upper` and optional `num_teams_lower`
+ - If lower bound not specified, it defaults to upper bound value
+ - Format: `num_teams(lower : type to upper : type)` or `num_teams(to upper : type)`
+ - Example: `num_teams(%lb : i32 to %ub : i32)` or `num_teams(to %ub : i32)`
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the dims modifier is explicitly present
+ bool hasDimsModifier() {
+ return getNumTeamsDims().has_value();
+ }
+
+ /// Returns the number of dimensions specified by dims modifier
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumTeamsDims());
+ }
+
+ /// Returns all dimension values as an operand range
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumTeamsValues();
+ }
+
+ /// Returns the value for a specific dimension index
+ /// Index must be less than getNumDimensions()
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
}];
}
@@ -1532,76 +1572,4 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
-//===----------------------------------------------------------------------===//
-// V6.1: `num_teams` clause with dims modifier
-//===----------------------------------------------------------------------===//
-
-class OpenMP_NumTeamsMultiDimClauseSkip<
- bit traits = false, bit arguments = false, bit assemblyFormat = false,
- bit description = false, bit extraClassDeclaration = false
- > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
- extraClassDeclaration> {
- let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
- Variadic<AnyInteger>:$num_teams_values
- );
-
- let optAssemblyFormat = [{
- `num_teams_multi_dim` `(` custom<DimsModifier>($num_teams_dims,
- $num_teams_values,
- type($num_teams_values)) `)`
- }];
-
- let description = [{
- The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
- the number of teams to be created in a multidimensional team space.
-
- The dims modifier for the num_teams_multi_dim clause specifies the number of
- dimensions for the league space (team space) that the clause arranges.
- The dimensions argument in the dims modifier specifies the number of
- dimensions and determines the length of the list argument. The list items
- are specified in ascending order according to the ordinal number of the
- dimensions (dimension 0, 1, 2, ..., N-1).
-
- - If `dims` is not specified: The space is unidimensional (1D) with a single value
- - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
- - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
-
- **Examples:**
- - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
- 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
- - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
- }];
-
- let extraClassDeclaration = [{
- /// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumTeamsDims().has_value();
- }
-
- /// Returns the number of dimensions specified by dims modifier
- /// Returns 1 if dims modifier is not present (unidimensional by default)
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
- return 1;
- return static_cast<unsigned>(*getNumTeamsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumTeamsValues();
- }
-
- /// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
- }
- }];
-}
-
-def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
-
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 5e399d12b98ad..1fcd7b3c23e10 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,8 +241,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
- OpenMP_ThreadLimitClause
+ OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index be913fa9272c1..2d95f743fcd08 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2625,8 +2625,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
- clauses.numTeamsDims, clauses.numTeamsValues,
+ clauses.ifExpr, clauses.numTeamsDims, clauses.numTeamsValues,
+ clauses.numTeamsLower, clauses.numTeamsUpper,
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
@@ -2648,14 +2648,57 @@ LogicalResult TeamsOp::verify() {
"in any OpenMP dialect operations");
// Check for num_teams clause restrictions
- if (auto numTeamsLowerBound = getNumTeamsLower()) {
- auto numTeamsUpperBound = getNumTeamsUpper();
- if (!numTeamsUpperBound)
- return emitError("expected num_teams upper bound to be defined if the "
- "lower bound is defined");
- if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
+ auto numTeamsDims = getNumTeamsDims();
+ auto numTeamsValues = getNumTeamsValues();
+ auto numTeamsLower = getNumTeamsLower();
+ auto numTeamsUpper = getNumTeamsUpper();
+
+ // Cannot use both dims modifier and unidimensional style
+ if (numTeamsDims.has_value() && (numTeamsLower || numTeamsUpper)) {
+ return emitError(
+ "num_teams with dims modifier cannot be used together with "
+ "lower/upper bounds (unidimensional style)");
+ }
+
+ // With dims modifier (multidimensional)
+ if (numTeamsDims.has_value()) {
+ if (numTeamsValues.empty()) {
+ return emitError(
+ "num_teams dims modifier requires values to be specified");
+ }
+
+ if (numTeamsValues.size() != static_cast<size_t>(*numTeamsDims)) {
+ return emitError("num_teams dims(")
+ << *numTeamsDims << ") specified but " << numTeamsValues.size()
+ << " values provided";
+ }
+
+ // All values must have the same type
+ if (!numTeamsValues.empty()) {
+ Type firstType = numTeamsValues.front().getType();
+ for (auto value : numTeamsValues) {
+ if (value.getType() != firstType) {
+ return emitError(
+ "num_teams dims modifier requires all values to have "
+ "the same type");
+ }
+ }
+ }
+ } else {
+ // Without dims modifier
+ if (!numTeamsValues.empty()) {
return emitError(
- "expected num_teams upper bound and lower bound to be the same type");
+ "num_teams values can only be specified with dims modifier");
+ }
+
+ if (numTeamsLower) {
+ if (!numTeamsUpper)
+ return emitError("expected num_teams upper bound to be defined if the "
+ "lower bound is defined");
+ if (numTeamsLower.getType() != numTeamsUpper.getType())
+ return emitError("expected num_teams upper bound and lower bound to be "
+ "the same type");
+ }
}
// Check for allocate clause restrictions
@@ -4467,66 +4510,133 @@ LogicalResult WorkdistributeOp::verify() {
}
//===----------------------------------------------------------------------===//
-// Parser and printer for Clauses with dims modifier
+// Helper: Parse dims modifier with values
+//===----------------------------------------------------------------------===//
+// Parses: dims(N): values : type (single type for all values)
+static ParseResult parseDimsModifierWithValues( // Parses `dims(N): %v0, ..., %vN-1 : T` (a single type T shared by all values).
+ OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types) { // On success: dimsAttr = N as i64, values = parsed operands, types = T repeated values.size() times.
+ if (failed(parser.parseOptionalKeyword("dims"))) { // `dims` keyword not present.
+ return failure(); // NOTE(review): identical result to a malformed `dims(...)` below, so parseNumTeamsClause cannot tell "absent" from "error" and falls through to the `to`/lower-bound forms (possibly after a diagnostic was already emitted); consider returning OptionalParseResult.
+ }
+
+ // Parse (N): values : type
+ int64_t dimsValue;
+ if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
+ parser.parseRParen() || parser.parseColon()) {
+ return failure();
+ }
+
+ if (parser.parseOperandList(values) || parser.parseColon()) { // N is not compared with values.size() here; TeamsOp::verify() reports that mismatch.
+ return failure();
+ }
+
+ // Parse single type (all values have same type)
+ Type valueType;
+ if (parser.parseType(valueType)) {
+ return failure();
+ }
+
+ // Fill types vector with same type for all values
+ types.assign(values.size(), valueType);
+
+ dimsAttr = parser.getBuilder().getI64IntegerAttr(dimsValue); // N <= 0 is not rejected here; presumably caught by the IntPositive constraint on num_teams_dims — confirm.
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_teams clause with dims modifier
//===----------------------------------------------------------------------===//
-// clause_name(dims(3): %v0, %v1, %v2 : i32, i32, i32)
-// clause_name(%v : i32)
static ParseResult
-parseDimsModifier(OpAsmParser &parser, IntegerAttr &dimsAttr,
- SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
- SmallVectorImpl<Type> &types) {
- std::optional<int64_t> dims;
- // Try to parse optional dims modifier: dims(N):
- if (succeeded(parser.parseOptionalKeyword("dims"))) {
- int64_t dimsValue;
- if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
- parser.parseRParen() || parser.parseColon()) {
+parseNumTeamsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types,
+ std::optional<OpAsmParser::UnresolvedOperand> &lowerBound,
+ Type &lowerBoundType,
+ std::optional<OpAsmParser::UnresolvedOperand> &upperBound,
+ Type &upperBoundType) {
+
+ // Format: num_teams(dims(N): values : type)
+ if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+ return success();
+ }
+
+ // Format: num_teams(to upper : type)
+ if (succeeded(parser.parseOptionalKeyword("to"))) {
+ OpAsmParser::UnresolvedOperand upperOperand;
+ if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+ parser.parseType(upperBoundType)) {
return failure();
}
- dims = dimsValue;
+ upperBound = upperOperand;
+ return success();
}
- // Parse the operand list
- if (parser.parseOperandList(values))
- return failure();
- // Parse colon and types
- if (parser.parseColon() || parser.parseTypeList(types))
- return failure();
- // Verify dims matches number of values if specified
- if (dims.has_value() && values.size() != static_cast<size_t>(*dims)) {
- return parser.emitError(parser.getCurrentLocation())
- << "dims(" << *dims << ") specified but " << values.size()
- << " values provided";
+ // Format: num_teams(lower : type to upper : type)
+ OpAsmParser::UnresolvedOperand lowerOperand;
+ if (parser.parseOperand(lowerOperand) || parser.parseColon() ||
+ parser.parseType(lowerBoundType)) {
+ return failure();
}
- // If dims not specified but we have values, it's implicitly unidimensional
- if (!dims.has_value() && values.size() != 1) {
+ if (failed(parser.parseKeyword("to"))) {
return parser.emitError(parser.getCurrentLocation())
- << "expected 1 value without dims modifier, but got "
- << values.size() << " values";
+ << "expected 'to' keyword in num_teams clause";
}
- // Convert to IntegerAttr
- if (dims.has_value()) {
- dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims);
+ OpAsmParser::UnresolvedOperand upperOperand;
+ if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+ parser.parseType(upperBoundType)) {
+ return failure();
}
+
+ lowerBound = lowerOperand;
+ upperBound = upperOperand;
return success();
}
-static void printDimsModifier(OpAsmPrinter &p, Operation *op,
- IntegerAttr dimsAttr, OperandRange values,
- TypeRange types) {
- // Print dims modifier if present
+//===----------------------------------------------------------------------===//
+// Helper: Print dims modifier with values
+//===----------------------------------------------------------------------===//
+// Prints: dims(N): values : type (single type for all values)
+static void printDimsModifierWithValues(OpAsmPrinter &p, IntegerAttr dimsAttr,
+ OperandRange values, TypeRange types) {
if (dimsAttr) {
p << "dims(" << dimsAttr.getInt() << "): ";
}
- // Print operands
p.printOperands(values);
- // Print types
+ // Print single type
p << " : ";
- llvm::interleaveComma(types, p);
+ if (!types.empty()) {
+ p << types.front();
+ }
+}
+
+static void printNumTeamsClause(OpAsmPrinter &p, Operation *op, // Printer for custom<NumTeamsClause>, emitting the same three forms parseNumTeamsClause accepts; `op` is unused.
+ IntegerAttr dimsAttr, OperandRange values,
+ TypeRange types, Value lowerBound,
+ Type lowerBoundType, Value upperBound,
+ Type upperBoundType) {
+ if (!values.empty()) { // NOTE(review): keyed on values, not dimsAttr; values without dims (rejected by TeamsOp::verify) would print without `dims(N):`, which parseNumTeamsClause cannot read back.
+ // Multidimensional: dims(N): values : type
+ printDimsModifierWithValues(p, dimsAttr, values, types);
+ } else if (upperBound) {
+ if (lowerBound) {
+ // Both bounds: lower : type to upper : type
+ p.printOperand(lowerBound);
+ p << " : " << lowerBoundType << " to ";
+ p.printOperand(upperBound);
+ p << " : " << upperBoundType;
+ } else {
+ // Upper only: to upper : type
+ p << " to "; // Leading space yields `num_teams( to %ub : T)`; presumably kept to match the previous declarative format's spacing — confirm against existing CHECK lines.
+ p.printOperand(upperBound);
+ p << " : " << upperBoundType;
+ }
+ } // Prints nothing when neither values nor an upper bound is present.
}
#define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 62619f07d6573..836cac9a53707 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1451,7 +1451,82 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_mismatch() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{num_teams dims(3) specified but 2 values provided}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {num_teams_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_with_bounds() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ %lb = arith.constant 3 : i32
+ %ub = arith.constant 4 : i32
+ // expected-error @below {{num_teams with dims modifier cannot be used together with lower/upper bounds (unidimensional style)}}
+ "omp.teams" (%v0, %v1, %lb, %ub) ({
+ omp.terminator
+ }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_values_without_dims() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i32
+ // expected-error @below {{num_teams values can only be specified with dims modifier}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_no_values() {
+ omp.target {
+ // expected-error @below {{num_teams dims modifier requires values to be specified}}
+ "omp.teams" () ({
+ omp.terminator
+ }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_type_mismatch() {
+ omp.target {
+ %v0 = arith.constant 1 : i32
+ %v1 = arith.constant 2 : i64
+ // expected-error @below {{num_teams dims modifier requires all values to have the same type}}
+ "omp.teams" (%v0, %v1) ({
+ omp.terminator
+ }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i64) -> ()
omp.terminator
}
return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index c1bd3cd0fa446..24c2383f3c3aa 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1109,8 +1109,8 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
- // CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
- omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) {
+ // CHECK: omp.teams num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+ omp.teams num_teams(dims(3): %lb, %ub, %ub : i32) {
// CHECK: omp.terminator
omp.terminator
}
>From 19dc8d148220d782e8f741d781ec21c1f942a59f Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 10 Dec 2025 15:40:37 +0530
Subject: [PATCH 6/9] Mark mlir->llvmir for num_teams with dims as NYI
---
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 12 ++++++++++++
1 file changed, 12 insertions(+)
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 65425e29bc148..a31c710320266 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2024,6 +2024,10 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*op)))
return failure();
+ if (op.getNumTeamsDims().has_value() || !op.getNumTeamsValues().empty()) {
+ return op.emitError("Lowering of num_teams with dims modifier is NYI.");
+ }
+
DenseMap<Value, llvm::Value *> reductionVariableMap;
unsigned numReductionVars = op.getNumReductionVars();
SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -6035,6 +6039,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
for (Operation *user : blockArg.getUsers()) {
llvm::TypeSwitch<Operation *>(user)
.Case([&](omp::TeamsOp teamsOp) {
+ // num_teams dims and values are not yet supported
+ assert(!teamsOp.getNumTeamsDims().has_value() &&
+ teamsOp.getNumTeamsValues().empty() &&
+ "Lowering of num_teams with dims modifier is NYI.");
if (teamsOp.getNumTeamsLower() == blockArg)
numTeamsLower = hostEvalVar;
else if (teamsOp.getNumTeamsUpper() == blockArg)
@@ -6157,6 +6165,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// host_eval, but instead evaluated prior to entry to the region. This
// ensures values are mapped and available inside of the target region.
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
+ // num_teams dims and values are not yet supported
+ assert(!teamsOp.getNumTeamsDims().has_value() &&
+ teamsOp.getNumTeamsValues().empty() &&
+ "Lowering of num_teams with dims modifier is NYI.");
numTeamsLower = teamsOp.getNumTeamsLower();
numTeamsUpper = teamsOp.getNumTeamsUpper();
threadLimit = teamsOp.getThreadLimit();
>From 6d53bb2cdbcba9e740ea0e6dffaccf656e0fd6f3 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 17:08:24 +0530
Subject: [PATCH 7/9] Rename num_teams dims fields and factor out num_teams verification
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 35 ++---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 142 ++++++++++--------
.../OpenMP/OpenMPToLLVMIRTranslation.cpp | 8 +-
mlir/test/Dialect/OpenMP/invalid.mlir | 18 +--
4 files changed, 108 insertions(+), 95 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 0b17a62c88d92..2beb6690f1736 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -974,15 +974,15 @@ class OpenMP_NumTeamsClauseSkip<
> : OpenMP_Clause<traits, arguments, assemblyFormat, description,
extraClassDeclaration> {
let arguments = (ins
- ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
- Variadic<AnyInteger>:$num_teams_values,
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_num_dims,
+ Variadic<AnyInteger>:$num_teams_dims_values,
Optional<AnyInteger>:$num_teams_lower,
Optional<AnyInteger>:$num_teams_upper
);
let optAssemblyFormat = [{
`num_teams` `(` custom<NumTeamsClause>(
- $num_teams_dims, $num_teams_values, type($num_teams_values),
+ $num_teams_num_dims, $num_teams_dims_values, type($num_teams_dims_values),
$num_teams_lower, type($num_teams_lower),
$num_teams_upper, type($num_teams_upper)
) `)`
@@ -993,7 +993,7 @@ class OpenMP_NumTeamsClauseSkip<
construct on which it appears.
With dims modifier: (OpenMP 6.1 requirement)
- - Uses `num_teams_dims` (dimension count) and `num_teams_values` (upper bounds list)
+ - Uses `num_teams_num_dims` (dimension count) and `num_teams_dims_values` (upper bounds list)
- Specifies upper bounds for each dimension (all must have same type)
- Format: `num_teams(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
- Example: `num_teams(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1007,28 +1007,23 @@ class OpenMP_NumTeamsClauseSkip<
let extraClassDeclaration = [{
/// Returns true if the dims modifier is explicitly present
- bool hasDimsModifier() {
- return getNumTeamsDims().has_value();
+ bool hasNumTeamsDimsModifier() {
+ return getNumTeamsNumDims().has_value() && getNumTeamsNumDims().value();
}
- /// Returns the number of dimensions specified by dims modifier
- unsigned getNumDimensions() {
- if (!hasDimsModifier())
+ /// Returns N from dims(N), else 1 (even if no dims values exist).
+ unsigned getNumTeamsDimsCount() {
+ if (!hasNumTeamsDimsModifier())
return 1;
- return static_cast<unsigned>(*getNumTeamsDims());
- }
-
- /// Returns all dimension values as an operand range
- ::mlir::OperandRange getDimensionValues() {
- return getNumTeamsValues();
+ return static_cast<unsigned>(*getNumTeamsNumDims());
}
/// Returns the value for a specific dimension index
- /// Index must be less than getNumDimensions()
- ::mlir::Value getDimensionValue(unsigned index) {
- assert(index < getDimensionValues().size() &&
- "Dimension index out of bounds");
- return getDimensionValues()[index];
+ /// Index must be less than getNumTeamsDimsValues().size().
+ ::mlir::Value getNumTeamsDimsValue(unsigned index) {
+ assert(index < getNumTeamsDimsValues().size() &&
+ "Num teams dims index out of bounds");
+ return getNumTeamsDimsValues()[index];
}
}];
}
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 2d95f743fcd08..c7ea15c8616a2 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2197,6 +2197,40 @@ LogicalResult TargetUpdateOp::verify() {
// TargetOp
//===----------------------------------------------------------------------===//
+// Helper: verify a dims(N) modifier (numDimsAttr) against its values.
+static LogicalResult verifyDimsModifier(Operation *op,
+ std::optional<IntegerAttr> numDimsAttr,
+ OperandRange dimsValues) {
+ if (numDimsAttr.has_value() && numDimsAttr.value()) {
+ if (dimsValues.empty()) {
+ return op->emitError("dims modifier requires values to be specified");
+ }
+ // dims(N) requires exactly N values.
+ if (dimsValues.size() != static_cast<size_t>(numDimsAttr->getInt())) {
+ return op->emitError("dims(")
+ << numDimsAttr->getInt() << ") specified but " << dimsValues.size()
+ << " values provided";
+ }
+ // All dims values must have the same type.
+ if (!dimsValues.empty()) {
+ Type firstType = dimsValues.front().getType();
+ for (auto value : dimsValues) {
+ if (value.getType() != firstType) {
+ return op->emitError(
+ "dims modifier requires all values to have the same type");
+ }
+ }
+ }
+ return success();
+ } else {
+ if (!dimsValues.empty()) {
+ return op->emitError(
+ "dims values can only be specified with dims modifier");
+ }
+ }
+ return success();
+}
+
void TargetOp::build(OpBuilder &builder, OperationState &state,
const TargetOperands &clauses) {
MLIRContext *ctx = builder.getContext();
@@ -2624,15 +2658,49 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
const TeamsOperands &clauses) {
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
- TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
- clauses.ifExpr, clauses.numTeamsDims, clauses.numTeamsValues,
- clauses.numTeamsLower, clauses.numTeamsUpper,
- /*private_vars=*/{}, /*private_syms=*/nullptr,
- /*private_needs_barrier=*/nullptr, clauses.reductionMod,
- clauses.reductionVars,
- makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
- makeArrayAttr(ctx, clauses.reductionSyms),
- clauses.threadLimit);
+ TeamsOp::build(
+ builder, state, clauses.allocateVars, clauses.allocatorVars,
+ clauses.ifExpr, clauses.numTeamsNumDims, clauses.numTeamsDimsValues,
+ clauses.numTeamsLower, clauses.numTeamsUpper,
+ /*private_vars=*/{}, /*private_syms=*/nullptr,
+ /*private_needs_barrier=*/nullptr, clauses.reductionMod,
+ clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+ makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimit);
+}
+
+// Verifies num_teams: either the dims form or the lower/upper bounds form.
+static LogicalResult
+verifyNumTeamsClause(Operation *op, std::optional<IntegerAttr> numTeamsNumDims,
+                     OperandRange numTeamsDimsValues, Value numTeamsLower,
+                     Value numTeamsUpper) {
+  bool hasDimsModifier = numTeamsNumDims.has_value() && numTeamsNumDims.value();
+  bool hasBounds = numTeamsLower || numTeamsUpper;
+  // The dims form and the lower/upper form are mutually exclusive. This also
+  // covers dims values given without an explicit dims modifier.
+  if (hasBounds && hasDimsModifier)
+    return op->emitError(
+        "num_teams with dims modifier cannot be used together with "
+        "lower/upper bounds");
+  if (hasBounds && !numTeamsDimsValues.empty())
+    return op->emitError(
+        "num_teams dims values cannot be used together with lower/upper "
+        "bounds");
+  // Dims form: check the number and type of the dims values.
+  if (failed(verifyDimsModifier(op, numTeamsNumDims, numTeamsDimsValues)))
+    return failure();
+  // Unidimensional form: a lower bound requires a same-typed upper bound.
+  if (!hasDimsModifier && numTeamsLower) {
+    if (!numTeamsUpper)
+      return op->emitError(
+          "expected num_teams upper bound to be defined if the "
+          "lower bound is defined");
+    if (numTeamsLower.getType() != numTeamsUpper.getType())
+      return op->emitError(
+          "expected num_teams upper bound and lower bound to be "
+          "the same type");
+  }
+  return success();
+}
LogicalResult TeamsOp::verify() {
@@ -2648,58 +2716,10 @@ LogicalResult TeamsOp::verify() {
"in any OpenMP dialect operations");
// Check for num_teams clause restrictions
- auto numTeamsDims = getNumTeamsDims();
- auto numTeamsValues = getNumTeamsValues();
- auto numTeamsLower = getNumTeamsLower();
- auto numTeamsUpper = getNumTeamsUpper();
-
- // Cannot use both dims modifier and unidimensional style
- if (numTeamsDims.has_value() && (numTeamsLower || numTeamsUpper)) {
- return emitError(
- "num_teams with dims modifier cannot be used together with "
- "lower/upper bounds (unidimensional style)");
- }
-
- // With dims modifier (multidimensional)
- if (numTeamsDims.has_value()) {
- if (numTeamsValues.empty()) {
- return emitError(
- "num_teams dims modifier requires values to be specified");
- }
-
- if (numTeamsValues.size() != static_cast<size_t>(*numTeamsDims)) {
- return emitError("num_teams dims(")
- << *numTeamsDims << ") specified but " << numTeamsValues.size()
- << " values provided";
- }
-
- // All values must have the same type
- if (!numTeamsValues.empty()) {
- Type firstType = numTeamsValues.front().getType();
- for (auto value : numTeamsValues) {
- if (value.getType() != firstType) {
- return emitError(
- "num_teams dims modifier requires all values to have "
- "the same type");
- }
- }
- }
- } else {
- // Without dims modifier
- if (!numTeamsValues.empty()) {
- return emitError(
- "num_teams values can only be specified with dims modifier");
- }
-
- if (numTeamsLower) {
- if (!numTeamsUpper)
- return emitError("expected num_teams upper bound to be defined if the "
- "lower bound is defined");
- if (numTeamsLower.getType() != numTeamsUpper.getType())
- return emitError("expected num_teams upper bound and lower bound to be "
- "the same type");
- }
- }
+ if (failed(verifyNumTeamsClause(
+ op, this->getNumTeamsNumDimsAttr(), this->getNumTeamsDimsValues(),
+ this->getNumTeamsLower(), this->getNumTeamsUpper())))
+ return failure();
// Check for allocate clause restrictions
if (getAllocateVars().size() != getAllocatorVars().size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index a31c710320266..14cea49cc722d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2024,7 +2024,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
if (failed(checkImplementationStatus(*op)))
return failure();
- if (op.getNumTeamsDims().has_value() || !op.getNumTeamsValues().empty()) {
+ if (op.hasNumTeamsDimsModifier()) {
return op.emitError("Lowering of num_teams with dims modifier is NYI.");
}
@@ -6040,8 +6040,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
llvm::TypeSwitch<Operation *>(user)
.Case([&](omp::TeamsOp teamsOp) {
// num_teams dims and values are not yet supported
- assert(!teamsOp.getNumTeamsDims().has_value() &&
- teamsOp.getNumTeamsValues().empty() &&
+ assert(!teamsOp.hasNumTeamsDimsModifier() &&
"Lowering of num_teams with dims modifier is NYI.");
if (teamsOp.getNumTeamsLower() == blockArg)
numTeamsLower = hostEvalVar;
@@ -6166,8 +6165,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// ensures values are mapped and available inside of the target region.
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
// num_teams dims and values are not yet supported
- assert(!teamsOp.getNumTeamsDims().has_value() &&
- teamsOp.getNumTeamsValues().empty() &&
+ assert(!teamsOp.hasNumTeamsDimsModifier() &&
"Lowering of num_teams with dims modifier is NYI.");
numTeamsLower = teamsOp.getNumTeamsLower();
numTeamsUpper = teamsOp.getNumTeamsUpper();
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 836cac9a53707..dd367aba8da27 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1463,10 +1463,10 @@ func.func @omp_teams_num_teams_dims_mismatch() {
omp.target {
%v0 = arith.constant 1 : i32
%v1 = arith.constant 2 : i32
- // expected-error @below {{num_teams dims(3) specified but 2 values provided}}
+ // expected-error @below {{dims(3) specified but 2 values provided}}
"omp.teams" (%v0, %v1) ({
omp.terminator
- }) {num_teams_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+ }) {num_teams_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
omp.terminator
}
return
@@ -1480,10 +1480,10 @@ func.func @omp_teams_num_teams_dims_with_bounds() {
%v1 = arith.constant 2 : i32
%lb = arith.constant 3 : i32
%ub = arith.constant 4 : i32
- // expected-error @below {{num_teams with dims modifier cannot be used together with lower/upper bounds (unidimensional style)}}
+ // expected-error @below {{num_teams with dims modifier cannot be used together with lower/upper bounds}}
"omp.teams" (%v0, %v1, %lb, %ub) ({
omp.terminator
- }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
+ }) {num_teams_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
omp.terminator
}
return
@@ -1495,7 +1495,7 @@ func.func @omp_teams_num_teams_values_without_dims() {
omp.target {
%v0 = arith.constant 1 : i32
%v1 = arith.constant 2 : i32
- // expected-error @below {{num_teams values can only be specified with dims modifier}}
+ // expected-error @below {{dims values can only be specified with dims modifier}}
"omp.teams" (%v0, %v1) ({
omp.terminator
}) {operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
@@ -1508,10 +1508,10 @@ func.func @omp_teams_num_teams_values_without_dims() {
func.func @omp_teams_num_teams_dims_no_values() {
omp.target {
- // expected-error @below {{num_teams dims modifier requires values to be specified}}
+ // expected-error @below {{dims modifier requires values to be specified}}
"omp.teams" () ({
omp.terminator
- }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+ }) {num_teams_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
omp.terminator
}
return
@@ -1523,10 +1523,10 @@ func.func @omp_teams_num_teams_dims_type_mismatch() {
omp.target {
%v0 = arith.constant 1 : i32
%v1 = arith.constant 2 : i64
- // expected-error @below {{num_teams dims modifier requires all values to have the same type}}
+ // expected-error @below {{dims modifier requires all values to have the same type}}
"omp.teams" (%v0, %v1) ({
omp.terminator
- }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i64) -> ()
+ }) {num_teams_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i64) -> ()
omp.terminator
}
return
>From 01cb639c0da401e78f51c6208ab92b9e648e5340 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 16 Dec 2025 19:58:14 +0530
Subject: [PATCH 8/9] Allow a single num_teams dims value without the dims modifier
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index c7ea15c8616a2..03897c4c97df8 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2223,7 +2223,7 @@ static LogicalResult verifyDimsModifier(Operation *op,
}
return success();
} else {
- if (!dimsValues.empty()) {
+ if (dimsValues.size() > 1) {
return op->emitError(
"dims values can only be specified with dims modifier");
}
>From d5c7b19b8dc2ff3e66136109543e600846fdecf7 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 09:51:21 +0530
Subject: [PATCH 9/9] Use IntLikeType for num_teams dims values
---
mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 2 +-
.../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 11 +++++++----
2 files changed, 8 insertions(+), 5 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 2beb6690f1736..b949e2629a095 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -975,7 +975,7 @@ class OpenMP_NumTeamsClauseSkip<
extraClassDeclaration> {
let arguments = (ins
ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_num_dims,
- Variadic<AnyInteger>:$num_teams_dims_values,
+ Variadic<IntLikeType>:$num_teams_dims_values,
Optional<AnyInteger>:$num_teams_lower,
Optional<AnyInteger>:$num_teams_upper
);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 14cea49cc722d..965a399fd653f 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2025,7 +2025,8 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
return failure();
if (op.hasNumTeamsDimsModifier()) {
- return op.emitError("Lowering of num_teams with dims modifier is NYI.");
+ return op.emitError(
+ "Lowering of num_teams with dims modifier is not yet implemented.");
}
DenseMap<Value, llvm::Value *> reductionVariableMap;
@@ -6041,7 +6042,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
.Case([&](omp::TeamsOp teamsOp) {
// num_teams dims and values are not yet supported
assert(!teamsOp.hasNumTeamsDimsModifier() &&
- "Lowering of num_teams with dims modifier is NYI.");
+ "Lowering of num_teams with dims modifier is not yet "
+ "implemented.");
if (teamsOp.getNumTeamsLower() == blockArg)
numTeamsLower = hostEvalVar;
else if (teamsOp.getNumTeamsUpper() == blockArg)
@@ -6165,8 +6167,9 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
// ensures values are mapped and available inside of the target region.
if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
// num_teams dims and values are not yet supported
- assert(!teamsOp.hasNumTeamsDimsModifier() &&
- "Lowering of num_teams with dims modifier is NYI.");
+ assert(
+ !teamsOp.hasNumTeamsDimsModifier() &&
+ "Lowering of num_teams with dims modifier is not yet implemented.");
numTeamsLower = teamsOp.getNumTeamsLower();
numTeamsUpper = teamsOp.getNumTeamsUpper();
threadLimit = teamsOp.getThreadLimit();
More information about the Mlir-commits
mailing list