[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams clause with dims modifier support (PR #169883)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Fri Nov 28 00:54:52 PST 2025
https://github.com/skc7 created https://github.com/llvm/llvm-project/pull/169883
This is a WIP PR adding support for the OpenMP 6.2 feature `num_teams` with the `dims` modifier.
To avoid breaking existing code, the clause is named `num_teams_multi_dim` for now. It will be renamed back to `num_teams` once the OpenMPIRBuilder supports creating `teams` with the `dims` modifier argument.
>From f82593248ecb6218345ce03a59271a4589c3c17a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 13:37:14 +0530
Subject: [PATCH 1/2] [OpenMP][MLIR] Add num_teams clause with dims modifier
support
---
.../mlir/Dialect/OpenMP/OpenMPClauses.td | 72 +++++++++++++++++++
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 3 +-
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 64 +++++++++++++++++
mlir/test/Dialect/OpenMP/invalid.mlir | 4 +-
mlir/test/Dialect/OpenMP/ops.mlir | 6 ++
5 files changed, 146 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8e43c4284d078..ac452177b5cf7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1532,4 +1532,76 @@ class OpenMP_UseDevicePtrClauseSkip<
def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
+//===----------------------------------------------------------------------===//
+// V6.2: Multidimensional `num_teams` clause with dims modifier
+//===----------------------------------------------------------------------===//
+
+// Clause class for `num_teams` with the OpenMP 6.2 `dims` modifier. It is
+// temporarily spelled `num_teams_multi_dim` so that it can be attached to ops
+// alongside the existing OpenMP_NumTeamsClause.
+class OpenMP_NumTeamsMultiDimClauseSkip<
+ bit traits = false, bit arguments = false, bit assemblyFormat = false,
+ bit description = false, bit extraClassDeclaration = false
+ > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+ extraClassDeclaration> {
+ // `num_teams_dims` holds N from `dims(N)` and stays unset when the modifier
+ // is omitted; `num_teams_values` holds the per-dimension integer operands.
+ let arguments = (ins
+ ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+ Variadic<AnyInteger>:$num_teams_values
+ );
+
+ // Parsed/printed by parseNumTeamsMultiDimClause and
+ // printNumTeamsMultiDimClause in OpenMPDialect.cpp.
+ let optAssemblyFormat = [{
+ `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+ $num_teams_values,
+ type($num_teams_values)) `)`
+ }];
+
+ let description = [{
+ The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
+ the number of teams to be created in a multidimensional team space.
+
+ The dims modifier for the num_teams_multi_dim clause specifies the number of
+ dimensions for the league space (team space) that the clause arranges.
+ The dimensions argument in the dims modifier specifies the number of
+ dimensions and determines the length of the list argument. The list items
+ are specified in ascending order according to the ordinal number of the
+ dimensions (dimension 0, 1, 2, ..., N-1).
+
+ - If `dims` is not specified: The space is unidimensional (1D) with a single value
+ - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
+ - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
+
+ **Examples:**
+ - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
+ 3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
+ - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
+ }];
+
+ let extraClassDeclaration = [{
+ /// Returns true if the `dims` modifier was explicitly given, i.e. the
+ /// optional `num_teams_dims` attribute is set.
+ bool hasDimsModifier() {
+ return getNumTeamsDims().has_value();
+ }
+
+ /// Returns N from the `dims(N)` modifier, or 1 when the modifier is absent
+ /// (unidimensional by default). Note: 1 is also returned when the clause
+ /// itself is absent and there are no values at all.
+ unsigned getNumDimensions() {
+ if (!hasDimsModifier())
+ return 1;
+ return static_cast<unsigned>(*getNumTeamsDims());
+ }
+
+ /// Returns all per-dimension values, ordered by dimension (0, 1, ..., N-1).
+ ::mlir::OperandRange getDimensionValues() {
+ return getNumTeamsValues();
+ }
+
+ /// Returns the value for dimension `index`.
+ /// `index` must be less than getDimensionValues().size() (asserted below).
+ /// For ops produced by the custom parser this equals getNumDimensions();
+ /// NOTE(review): no verifier visible here enforces that for ops built
+ /// programmatically, and it does not hold when the clause is absent.
+ ::mlir::Value getDimensionValue(unsigned index) {
+ assert(index < getDimensionValues().size() &&
+ "Dimension index out of bounds");
+ return getDimensionValues()[index];
+ }
+ }];
+}
+
+def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
+
#endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index bbfe805eefe48..ea440cf924a95 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
], clauses = [
OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
- OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+ OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+ OpenMP_ThreadLimitClause
], singleRegion = true> {
let summary = "teams construct";
let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0d6b2870c625a..b9fbd38aebadb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2621,6 +2621,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+ clauses.numTeamsDims, clauses.numTeamsValues,
/*private_vars=*/{}, /*private_syms=*/nullptr,
/*private_needs_barrier=*/nullptr, clauses.reductionMod,
clauses.reductionVars,
@@ -4453,6 +4454,69 @@ LogicalResult WorkdistributeOp::verify() {
return success();
}
+//===----------------------------------------------------------------------===//
+// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
+//===----------------------------------------------------------------------===//
+// num_teams_multidim ::= `num_teams` `(` [`dims` `(` dim-count `)` `:`] values
+// `)` Example: num_teams(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
+// num_teams(%v : i32)
+static ParseResult parseNumTeamsMultiDimClause(
+ OpAsmParser &parser, IntegerAttr &dimsAttr,
+ SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+ SmallVectorImpl<Type> &types) {
+  // Parses the body of `num_teams_multi_dim(...)`:
+  //   [`dims` `(` N `)` `:`] operand-list `:` type-list
+  // On success `dimsAttr` holds N as an i64 attribute when the modifier is
+  // present; it is left untouched otherwise.
+
+  // Count diagnostics refer to the start of the clause body. After parsing,
+  // the current location would point past the type list instead.
+  llvm::SMLoc clauseLoc = parser.getCurrentLocation();
+
+  std::optional<int64_t> dims;
+  if (succeeded(parser.parseOptionalKeyword("dims"))) {
+    if (parser.parseLParen())
+      return failure();
+    llvm::SMLoc dimsLoc = parser.getCurrentLocation();
+    int64_t dimsValue;
+    if (parser.parseInteger(dimsValue) || parser.parseRParen() ||
+        parser.parseColon())
+      return failure();
+    // Reject zero/negative counts up front. Otherwise the length check below
+    // reports a confusing mismatch (a negative count becomes a huge size_t)
+    // before the attribute's IntPositive constraint is ever verified.
+    if (dimsValue <= 0)
+      return parser.emitError(dimsLoc)
+             << "expected dims modifier value to be positive, got "
+             << dimsValue;
+    dims = dimsValue;
+  }
+
+  if (parser.parseOperandList(values) || parser.parseColon() ||
+      parser.parseTypeList(types))
+    return failure();
+
+  if (dims) {
+    // With `dims(N)`, exactly N values (one per dimension) are required.
+    if (values.size() != static_cast<size_t>(*dims))
+      return parser.emitError(clauseLoc)
+             << "dims(" << *dims << ") specified but " << values.size()
+             << " values provided";
+    dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims);
+    return success();
+  }
+
+  // Without the modifier the team space is unidimensional: exactly one value.
+  if (values.size() != 1)
+    return parser.emitError(clauseLoc)
+           << "expected 1 value without dims modifier, got " << values.size();
+  return success();
+}
+
+static void printNumTeamsMultiDimClause(OpAsmPrinter &p, Operation *op,
+ IntegerAttr dimsAttr,
+ OperandRange values, TypeRange types) {
+  // Emits `[dims(N): ]%v0, %v1, ... : t0, t1, ...`, which is the form that
+  // parseNumTeamsMultiDimClause accepts, so the clause round-trips.
+  if (dimsAttr) {
+    int64_t dimCount = dimsAttr.getInt();
+    p << "dims(" << dimCount << "): ";
+  }
+  p.printOperands(values);
+  p << " : ";
+  llvm::interleaveComma(types, p);
+}
+
#define GET_ATTRDEF_CLASSES
#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..62619f07d6573 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
// expected-error @below {{expected equal sizes for allocate and allocator variables}}
"omp.teams" (%data_var) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+ }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
omp.terminator
}
return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
// expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
"omp.teams" (%lb) ({
omp.terminator
- }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+ }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
omp.terminator
}
return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index ac29e20907b55..bb154fad12742 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1108,6 +1108,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
omp.terminator
}
+ // CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
+ omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) {
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+
// Test if.
// CHECK: omp.teams if(%{{.+}})
omp.teams if(%if_cond) {
>From 20d58da6af969212d6ee751f0d22a9a8942aa51a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 14:20:39 +0530
Subject: [PATCH 2/2] fix comment
---
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++---
1 file changed, 2 insertions(+), 3 deletions(-)
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index b9fbd38aebadb..b26e22e9a781c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4457,9 +4457,8 @@ LogicalResult WorkdistributeOp::verify() {
//===----------------------------------------------------------------------===//
// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
//===----------------------------------------------------------------------===//
-// num_teams_multidim ::= `num_teams` `(` [`dims` `(` dim-count `)` `:`] values
-// `)` Example: num_teams(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
-// num_teams(%v : i32)
+// Examples: num_teams_multi_dim(dims(3): %v0, %v1, %v2 : i32, i32, i32)
+//       and num_teams_multi_dim(%v : i32)
static ParseResult parseNumTeamsMultiDimClause(
OpAsmParser &parser, IntegerAttr &dimsAttr,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
More information about the Mlir-commits
mailing list