[Mlir-commits] [mlir] [OpenMP][MLIR] Add num_teams clause with dims modifier support (PR #169883)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Wed Dec 10 02:15:07 PST 2025


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/169883

>From f82593248ecb6218345ce03a59271a4589c3c17a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 13:37:14 +0530
Subject: [PATCH 1/6] [OpenMP][MLIR] Add num_teams clause with dims modifier
 support

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 72 +++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  3 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 64 +++++++++++++++++
 mlir/test/Dialect/OpenMP/invalid.mlir         |  4 +-
 mlir/test/Dialect/OpenMP/ops.mlir             |  6 ++
 5 files changed, 146 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8e43c4284d078..ac452177b5cf7 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1532,4 +1532,76 @@ class OpenMP_UseDevicePtrClauseSkip<
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V6.2: Multidimensional `num_teams` clause with dims modifier
+//===----------------------------------------------------------------------===//
+
+class OpenMP_NumTeamsMultiDimClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+            extraClassDeclaration> {
+  let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+    Variadic<AnyInteger>:$num_teams_values
+  );
+
+  let optAssemblyFormat = [{
+    `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+                                                    $num_teams_values,
+                                                    type($num_teams_values)) `)`
+  }];
+
+  let description = [{
+    The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
+    the number of teams to be created in a multidimensional team space.
+
+    The dims modifier for the num_teams_multi_dim clause specifies the number of
+    dimensions for the league space (team space) that the clause arranges.
+    The dimensions argument in the dims modifier specifies the number of
+    dimensions and determines the length of the list argument. The list items
+    are specified in ascending order according to the ordinal number of the
+    dimensions (dimension 0, 1, 2, ..., N-1).
+
+    - If `dims` is not specified: The space is unidimensional (1D) with a single value
+    - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
+    - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
+
+    **Examples:**
+    - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
+      3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
+    - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasDimsModifier() {
+      return getNumTeamsDims().has_value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    /// Returns 1 if dims modifier is not present (unidimensional by default)
+    unsigned getNumDimensions() {
+      if (!hasDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumTeamsDims());
+    }
+
+    /// Returns all dimension values as an operand range
+    ::mlir::OperandRange getDimensionValues() {
+      return getNumTeamsValues();
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumDimensions()
+    ::mlir::Value getDimensionValue(unsigned index) {
+    assert(index < getDimensionValues().size() &&
+        "Dimension index out of bounds");
+      return getDimensionValues()[index];
+    }
+  }];
+}
+
+def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
+
 #endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index bbfe805eefe48..ea440cf924a95 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
     AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
-    OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+    OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+    OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "teams construct";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 0d6b2870c625a..b9fbd38aebadb 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2621,6 +2621,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
   // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
                  clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+                 clauses.numTeamsDims, clauses.numTeamsValues,
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
                  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
                  clauses.reductionVars,
@@ -4453,6 +4454,69 @@ LogicalResult WorkdistributeOp::verify() {
   return success();
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
+//===----------------------------------------------------------------------===//
+// num_teams_multidim ::= `num_teams` `(` [`dims` `(` dim-count `)` `:`] values
+// `)` Example: num_teams(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
+// num_teams(%v : i32)
+static ParseResult parseNumTeamsMultiDimClause(
+    OpAsmParser &parser, IntegerAttr &dimsAttr,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    SmallVectorImpl<Type> &types) {
+  std::optional<int64_t> dims;
+  // Try to parse optional dims modifier: dims(N):
+  if (succeeded(parser.parseOptionalKeyword("dims"))) {
+    int64_t dimsValue;
+    if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
+        parser.parseRParen() || parser.parseColon()) {
+      return failure();
+    }
+    dims = dimsValue;
+  }
+  // Parse the operand list
+  if (parser.parseOperandList(values))
+    return failure();
+  // Parse colon and types
+  if (parser.parseColon() || parser.parseTypeList(types))
+    return failure();
+
+  // Verify dims matches number of values if specified
+  if (dims.has_value() && values.size() != static_cast<size_t>(*dims)) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "dims(" << *dims << ") specified but " << values.size()
+           << " values provided";
+  }
+
+  // If dims not specified but we have values, it's implicitly unidimensional
+  if (!dims.has_value() && values.size() != 1) {
+    return parser.emitError(parser.getCurrentLocation())
+           << "expected 1 value without dims modifier, got " << values.size();
+  }
+
+  // Convert to IntegerAttr
+  if (dims.has_value()) {
+    dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims);
+  }
+  return success();
+}
+
+static void printNumTeamsMultiDimClause(OpAsmPrinter &p, Operation *op,
+                                        IntegerAttr dimsAttr,
+                                        OperandRange values, TypeRange types) {
+  // Print dims modifier if present
+  if (dimsAttr) {
+    p << "dims(" << dimsAttr.getInt() << "): ";
+  }
+
+  // Print operands
+  p.printOperands(values);
+
+  // Print types
+  p << " : ";
+  llvm::interleaveComma(types, p);
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index af24d969064ab..62619f07d6573 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
     // expected-error @below {{expected equal sizes for allocate and allocator variables}}
     "omp.teams" (%data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
     omp.terminator
   }
   return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
     // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
     "omp.teams" (%lb) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0>} : (i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
     omp.terminator
   }
   return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index ac29e20907b55..bb154fad12742 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1108,6 +1108,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
+  // CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
+  omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
   // Test if.
   // CHECK: omp.teams if(%{{.+}})
   omp.teams if(%if_cond) {

>From 20d58da6af969212d6ee751f0d22a9a8942aa51a Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 14:20:39 +0530
Subject: [PATCH 2/6] fix comment

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index b9fbd38aebadb..b26e22e9a781c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4457,9 +4457,8 @@ LogicalResult WorkdistributeOp::verify() {
 //===----------------------------------------------------------------------===//
 // Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
 //===----------------------------------------------------------------------===//
-// num_teams_multidim ::= `num_teams` `(` [`dims` `(` dim-count `)` `:`] values
-// `)` Example: num_teams(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
-// num_teams(%v : i32)
+// num_teams_multi_dim(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
+// num_teams_multi_dim(%v : i32)
 static ParseResult parseNumTeamsMultiDimClause(
     OpAsmParser &parser, IntegerAttr &dimsAttr,
     SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,

>From 622e7bd93a3c09cc90bf1e1d607c2ba878b5d749 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 1 Dec 2025 12:56:28 +0530
Subject: [PATCH 3/6] Comments fix

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 2 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp      | 3 ++-
 2 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index ac452177b5cf7..64d2ba8c1503f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1533,7 +1533,7 @@ class OpenMP_UseDevicePtrClauseSkip<
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
 //===----------------------------------------------------------------------===//
-// V6.2: Multidimensional `num_teams` clause with dims modifier
+// V6.1: Multidimensional `num_teams` clause with dims modifier
 //===----------------------------------------------------------------------===//
 
 class OpenMP_NumTeamsMultiDimClauseSkip<
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index b26e22e9a781c..8705ea6f220c3 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4490,7 +4490,8 @@ static ParseResult parseNumTeamsMultiDimClause(
   // If dims not specified but we have values, it's implicitly unidimensional
   if (!dims.has_value() && values.size() != 1) {
     return parser.emitError(parser.getCurrentLocation())
-           << "expected 1 value without dims modifier, got " << values.size();
+           << "expected 1 value without dims modifier, but got "
+           << values.size() << " values";
   }
 
   // Convert to IntegerAttr

>From 17832c12abdd0fa931362fe4ef0feba956683997 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Mon, 1 Dec 2025 15:02:21 +0530
Subject: [PATCH 4/6] Use DimsModifier for custom assembly parser and printer

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  4 ++--
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 20 +++++++++----------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 64d2ba8c1503f..d26b3cfb7a86d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1533,7 +1533,7 @@ class OpenMP_UseDevicePtrClauseSkip<
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
 //===----------------------------------------------------------------------===//
-// V6.1: Multidimensional `num_teams` clause with dims modifier
+// V6.1: `num_teams` clause with dims modifier
 //===----------------------------------------------------------------------===//
 
 class OpenMP_NumTeamsMultiDimClauseSkip<
@@ -1547,7 +1547,7 @@ class OpenMP_NumTeamsMultiDimClauseSkip<
   );
 
   let optAssemblyFormat = [{
-    `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+    `num_teams_multi_dim` `(` custom<DimsModifier>($num_teams_dims,
                                                     $num_teams_values,
                                                     type($num_teams_values)) `)`
   }];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 8705ea6f220c3..66b36caccada3 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4455,14 +4455,14 @@ LogicalResult WorkdistributeOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// Parser and printer for NumTeamsMultiDim Clause (with dims modifier)
+// Parser and printer for Clauses with dims modifier
 //===----------------------------------------------------------------------===//
-// num_teams_multi_dim(dims(3): %v0, %v1, %v2 : i32, i32, i32) Or:
-// num_teams_multi_dim(%v : i32)
-static ParseResult parseNumTeamsMultiDimClause(
-    OpAsmParser &parser, IntegerAttr &dimsAttr,
-    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-    SmallVectorImpl<Type> &types) {
+// clause_name(dims(3): %v0, %v1, %v2 : i32, i32, i32)
+// clause_name(%v : i32)
+static ParseResult
+parseDimsModifier(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                  SmallVectorImpl<Type> &types) {
   std::optional<int64_t> dims;
   // Try to parse optional dims modifier: dims(N):
   if (succeeded(parser.parseOptionalKeyword("dims"))) {
@@ -4501,9 +4501,9 @@ static ParseResult parseNumTeamsMultiDimClause(
   return success();
 }
 
-static void printNumTeamsMultiDimClause(OpAsmPrinter &p, Operation *op,
-                                        IntegerAttr dimsAttr,
-                                        OperandRange values, TypeRange types) {
+static void printDimsModifier(OpAsmPrinter &p, Operation *op,
+                              IntegerAttr dimsAttr, OperandRange values,
+                              TypeRange types) {
   // Print dims modifier if present
   if (dimsAttr) {
     p << "dims(" << dimsAttr.getInt() << "): ";

>From 885c344014f979b073ef707ca1fccf63600bdec2 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 10 Dec 2025 13:13:17 +0530
Subject: [PATCH 5/6] use dims modifier in main num_teams clause itself instead
 of creating a new clause

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 128 +++++------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |   3 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 204 ++++++++++++++----
 mlir/test/Dialect/OpenMP/invalid.mlir         |  77 ++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             |   4 +-
 5 files changed, 284 insertions(+), 132 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d26b3cfb7a86d..1b44873ea99b1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -974,22 +974,62 @@ class OpenMP_NumTeamsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+    Variadic<AnyInteger>:$num_teams_values,
     Optional<AnyInteger>:$num_teams_lower,
     Optional<AnyInteger>:$num_teams_upper
   );
 
   let optAssemblyFormat = [{
-    `num_teams` `(` ( $num_teams_lower^ `:` type($num_teams_lower) )? `to`
-                      $num_teams_upper `:` type($num_teams_upper) `)`
+    `num_teams` `(` custom<NumTeamsClause>(
+      $num_teams_dims, $num_teams_values, type($num_teams_values),
+      $num_teams_lower, type($num_teams_lower),
+      $num_teams_upper, type($num_teams_upper)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_teams_upper` and `num_teams_lower` arguments specify the
-    limit on the number of teams to be created. If only the upper bound is
-    specified, it acts as if the lower bound was set to the same value. It is
-    not allowed to set `num_teams_lower` if `num_teams_upper` is not specified.
-    They define a closed range, where both the lower and upper bounds are
-    included.
+    The `num_teams` clause specifies the bounds on the league space formed by the
+    construct on which it appears.
+
+    With dims modifier: (OpenMP 6.1 requirement)
+    - Uses `num_teams_dims` (dimension count) and `num_teams_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_teams(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_teams(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_teams_upper` and optional `num_teams_lower`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_teams(lower : type to upper : type)` or `num_teams(to upper : type)`
+    - Example: `num_teams(%lb : i32 to %ub : i32)` or `num_teams(to %ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present.
+    bool hasDimsModifier() { return getNumTeamsDims().has_value(); }
+
+    /// Returns the dimension count given by the dims modifier, or 1 for the
+    /// unidimensional `[lower] to upper` form (also when num_teams is absent).
+    unsigned getNumDimensions() {
+      return hasDimsModifier() ? static_cast<unsigned>(*getNumTeamsDims()) : 1;
+    }
+
+    /// Returns the per-dimension upper bounds given with the dims modifier.
+    /// Empty for the unidimensional form, whose bound is getNumTeamsUpper().
+    ::mlir::OperandRange getDimensionValues() { return getNumTeamsValues(); }
+
+    /// Returns the upper bound for dimension `index` (< getNumDimensions()).
+    /// Without the dims modifier this is num_teams_upper, which may be null.
+    ::mlir::Value getDimensionValue(unsigned index) {
+      if (!hasDimsModifier()) {
+        assert(index == 0 && "Dimension index out of bounds");
+        return getNumTeamsUpper();
+      }
+      assert(index < getDimensionValues().size() &&
+             "Dimension index out of bounds");
+      return getDimensionValues()[index];
+    }
   }];
 }
 
@@ -1532,76 +1572,4 @@ class OpenMP_UseDevicePtrClauseSkip<
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
-//===----------------------------------------------------------------------===//
-// V6.1: `num_teams` clause with dims modifier
-//===----------------------------------------------------------------------===//
-
-class OpenMP_NumTeamsMultiDimClauseSkip<
-    bit traits = false, bit arguments = false, bit assemblyFormat = false,
-    bit description = false, bit extraClassDeclaration = false
-  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
-            extraClassDeclaration> {
-  let arguments = (ins
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
-    Variadic<AnyInteger>:$num_teams_values
-  );
-
-  let optAssemblyFormat = [{
-    `num_teams_multi_dim` `(` custom<DimsModifier>($num_teams_dims,
-                                                    $num_teams_values,
-                                                    type($num_teams_values)) `)`
-  }];
-
-  let description = [{
-    The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
-    the number of teams to be created in a multidimensional team space.
-
-    The dims modifier for the num_teams_multi_dim clause specifies the number of
-    dimensions for the league space (team space) that the clause arranges.
-    The dimensions argument in the dims modifier specifies the number of
-    dimensions and determines the length of the list argument. The list items
-    are specified in ascending order according to the ordinal number of the
-    dimensions (dimension 0, 1, 2, ..., N-1).
-
-    - If `dims` is not specified: The space is unidimensional (1D) with a single value
-    - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
-    - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
-
-    **Examples:**
-    - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
-      3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
-    - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
-  }];
-
-  let extraClassDeclaration = [{
-    /// Returns true if the dims modifier is explicitly present
-    bool hasDimsModifier() {
-      return getNumTeamsDims().has_value();
-    }
-
-    /// Returns the number of dimensions specified by dims modifier
-    /// Returns 1 if dims modifier is not present (unidimensional by default)
-    unsigned getNumDimensions() {
-      if (!hasDimsModifier())
-        return 1;
-      return static_cast<unsigned>(*getNumTeamsDims());
-    }
-
-    /// Returns all dimension values as an operand range
-    ::mlir::OperandRange getDimensionValues() {
-      return getNumTeamsValues();
-    }
-
-    /// Returns the value for a specific dimension index
-    /// Index must be less than getNumDimensions()
-    ::mlir::Value getDimensionValue(unsigned index) {
-    assert(index < getDimensionValues().size() &&
-        "Dimension index out of bounds");
-      return getDimensionValues()[index];
-    }
-  }];
-}
-
-def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
-
 #endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index ea440cf924a95..bbfe805eefe48 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,8 +241,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
     AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
-    OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
-    OpenMP_ThreadLimitClause
+    OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "teams construct";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 66b36caccada3..6f56833b3b76a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2620,8 +2620,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
-                 clauses.numTeamsDims, clauses.numTeamsValues,
+                 clauses.ifExpr, clauses.numTeamsDims, clauses.numTeamsValues,
+                 clauses.numTeamsLower, clauses.numTeamsUpper,
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
                  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
                  clauses.reductionVars,
@@ -2643,14 +2643,57 @@ LogicalResult TeamsOp::verify() {
                      "in any OpenMP dialect operations");
 
   // Check for num_teams clause restrictions
-  if (auto numTeamsLowerBound = getNumTeamsLower()) {
-    auto numTeamsUpperBound = getNumTeamsUpper();
-    if (!numTeamsUpperBound)
-      return emitError("expected num_teams upper bound to be defined if the "
-                       "lower bound is defined");
-    if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
+  auto numTeamsDims = getNumTeamsDims();
+  auto numTeamsValues = getNumTeamsValues();
+  auto numTeamsLower = getNumTeamsLower();
+  auto numTeamsUpper = getNumTeamsUpper();
+
+  // The dims modifier and the lower/upper bounds form are mutually exclusive.
+  if (numTeamsDims.has_value() && (numTeamsLower || numTeamsUpper)) {
+    return emitError(
+        "num_teams with dims modifier cannot be used together with "
+        "lower/upper bounds (unidimensional style)");
+  }
+
+  // With dims modifier: exactly N values are required, all of one type.
+  if (numTeamsDims.has_value()) {
+    if (numTeamsValues.empty()) {
+      return emitError(
+          "num_teams dims modifier requires values to be specified");
+    }
+
+    if (numTeamsValues.size() != static_cast<size_t>(*numTeamsDims)) {
+      return emitError("num_teams dims(")
+             << *numTeamsDims << ") specified but " << numTeamsValues.size()
+             << " values provided";
+    }
+
+    // All values must share one type (values is known non-empty here).
+    if (!numTeamsValues.empty()) {
+      Type firstType = numTeamsValues.front().getType();
+      for (auto value : numTeamsValues) {
+        if (value.getType() != firstType) {
+          return emitError(
+              "num_teams dims modifier requires all values to have "
+              "the same type");
+        }
+      }
+    }
+  } else {
+    // Without dims modifier: reject values and validate the bounds form.
+    if (!numTeamsValues.empty()) {
       return emitError(
-          "expected num_teams upper bound and lower bound to be the same type");
+          "num_teams values can only be specified with dims modifier");
+    }
+
+    if (numTeamsLower) {
+      if (!numTeamsUpper)
+        return emitError("expected num_teams upper bound to be defined if the "
+                         "lower bound is defined");
+      if (numTeamsLower.getType() != numTeamsUpper.getType())
+        return emitError("expected num_teams upper bound and lower bound to be "
+                         "the same type");
+    }
   }
 
   // Check for allocate clause restrictions
@@ -4455,66 +4498,139 @@ LogicalResult WorkdistributeOp::verify() {
 }
 
 //===----------------------------------------------------------------------===//
-// Parser and printer for Clauses with dims modifier
+// Helper: Parse dims modifier with values
+//===----------------------------------------------------------------------===//
+// Parses the remainder of a dims modifier, `(N): values : type`, after the
+// caller has consumed the `dims` keyword; one type is given for all values.
+// Any failure here is a hard error (a diagnostic has already been emitted),
+// so callers must propagate it rather than retry another clause form.
+static ParseResult parseDimsModifierWithValues(
+    OpAsmParser &parser, IntegerAttr &dimsAttr,
+    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+    SmallVectorImpl<Type> &types) {
+  // Parse (N): values : type
+  int64_t dimsValue;
+  if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
+      parser.parseRParen() || parser.parseColon()) {
+    return failure();
+  }
+
+  if (parser.parseOperandList(values) || parser.parseColon()) {
+    return failure();
+  }
+
+  // Parse single type (all values have same type)
+  Type valueType;
+  if (parser.parseType(valueType)) {
+    return failure();
+  }
+
+  // Fill types vector with same type for all values
+  types.assign(values.size(), valueType);
+
+  dimsAttr = parser.getBuilder().getI64IntegerAttr(dimsValue);
+  return success();
+}
+
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_teams clause with dims modifier
 //===----------------------------------------------------------------------===//
-// clause_name(dims(3): %v0, %v1, %v2 : i32, i32, i32)
-// clause_name(%v : i32)
+// Accepted forms (between the `num_teams(` and `)` tokens):
+//   dims(N): %v0, ..., %vN-1 : type
+//   to %ub : type
+//   %lb : type to %ub : type
 static ParseResult
-parseDimsModifier(OpAsmParser &parser, IntegerAttr &dimsAttr,
-                  SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                  SmallVectorImpl<Type> &types) {
-  std::optional<int64_t> dims;
-  // Try to parse optional dims modifier: dims(N):
-  if (succeeded(parser.parseOptionalKeyword("dims"))) {
-    int64_t dimsValue;
-    if (parser.parseLParen() || parser.parseInteger(dimsValue) ||
-        parser.parseRParen() || parser.parseColon()) {
+parseNumTeamsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                    SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                    SmallVectorImpl<Type> &types,
+                    std::optional<OpAsmParser::UnresolvedOperand> &lowerBound,
+                    Type &lowerBoundType,
+                    std::optional<OpAsmParser::UnresolvedOperand> &upperBound,
+                    Type &upperBoundType) {
+
+  // Format: num_teams(dims(N): values : type). Once `dims` has been consumed
+  // we are committed to this form; falling through to the bounds forms on
+  // failure would only produce cascading diagnostics.
+  if (succeeded(parser.parseOptionalKeyword("dims")))
+    return parseDimsModifierWithValues(parser, dimsAttr, values, types);
+
+  // Format: num_teams(to upper : type)
+  if (succeeded(parser.parseOptionalKeyword("to"))) {
+    OpAsmParser::UnresolvedOperand upperOperand;
+    if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+        parser.parseType(upperBoundType)) {
       return failure();
     }
-    dims = dimsValue;
+    upperBound = upperOperand;
+    return success();
   }
-  // Parse the operand list
-  if (parser.parseOperandList(values))
-    return failure();
-  // Parse colon and types
-  if (parser.parseColon() || parser.parseTypeList(types))
-    return failure();
 
-  // Verify dims matches number of values if specified
-  if (dims.has_value() && values.size() != static_cast<size_t>(*dims)) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "dims(" << *dims << ") specified but " << values.size()
-           << " values provided";
+  // Format: num_teams(lower : type to upper : type)
+  OpAsmParser::UnresolvedOperand lowerOperand;
+  if (parser.parseOperand(lowerOperand) || parser.parseColon() ||
+      parser.parseType(lowerBoundType)) {
+    return failure();
   }
 
-  // If dims not specified but we have values, it's implicitly unidimensional
-  if (!dims.has_value() && values.size() != 1) {
-    return parser.emitError(parser.getCurrentLocation())
-           << "expected 1 value without dims modifier, but got "
-           << values.size() << " values";
-  }
+  // parseKeyword itself reports "expected 'to' keyword in num_teams clause",
+  // so a separate emitError here would duplicate the diagnostic.
+  if (parser.parseKeyword("to", " keyword in num_teams clause"))
+    return failure();
 
-  // Convert to IntegerAttr
-  if (dims.has_value()) {
-    dimsAttr = parser.getBuilder().getI64IntegerAttr(*dims);
+  OpAsmParser::UnresolvedOperand upperOperand;
+  if (parser.parseOperand(upperOperand) || parser.parseColon() ||
+      parser.parseType(upperBoundType)) {
+    return failure();
   }
+
+  lowerBound = lowerOperand;
+  upperBound = upperOperand;
   return success();
 }
 
-static void printDimsModifier(OpAsmPrinter &p, Operation *op,
-                              IntegerAttr dimsAttr, OperandRange values,
-                              TypeRange types) {
-  // Print dims modifier if present
+//===----------------------------------------------------------------------===//
+// Helper: Print dims modifier with values
+//===----------------------------------------------------------------------===//
+// Prints: dims(N): values : type (single type for all values)
+static void printDimsModifierWithValues(OpAsmPrinter &p, IntegerAttr dimsAttr,
+                                        OperandRange values, TypeRange types) {
   if (dimsAttr) {
     p << "dims(" << dimsAttr.getInt() << "): ";
   }
 
-  // Print operands
   p.printOperands(values);
 
-  // Print types
+  // Print single type
   p << " : ";
-  llvm::interleaveComma(types, p);
+  if (!types.empty()) {
+    p << types.front();
+  }
+}
+
+// Printer counterpart of parseNumTeamsClause: emits the dims form when values
+// are present, else the `[lb : type] to ub : type` form (nothing if no bound).
+static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
+                                IntegerAttr dimsAttr, OperandRange values,
+                                TypeRange types, Value lowerBound,
+                                Type lowerBoundType, Value upperBound,
+                                Type upperBoundType) {
+  if (!values.empty()) {
+    // Multidimensional: dims(N): values : type
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  } else if (upperBound) {
+    if (lowerBound) {
+      // Both bounds: lower : type to upper : type
+      p.printOperand(lowerBound);
+      p << " : " << lowerBoundType << " to ";
+      p.printOperand(upperBound);
+      p << " : " << upperBoundType;
+    } else {
+      // Upper only: to upper : type
+      p << " to ";
+      p.printOperand(upperBound);
+      p << " : " << upperBoundType;
+    }
+  }
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 62619f07d6573..836cac9a53707 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1451,7 +1451,82 @@ func.func @omp_teams_num_teams1(%lb : i32) {
     // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
     "omp.teams" (%lb) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,1,0,0,0,0,0>} : (i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_mismatch() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    // expected-error @below {{num_teams dims(3) specified but 2 values provided}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {num_teams_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_with_bounds() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    %lb = arith.constant 3 : i32
+    %ub = arith.constant 4 : i32
+    // expected-error @below {{num_teams with dims modifier cannot be used together with lower/upper bounds (unidimensional style)}}
+    "omp.teams" (%v0, %v1, %lb, %ub) ({
+      omp.terminator
+    }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,1,1,0,0,0>} : (i32, i32, i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_values_without_dims() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    // expected-error @below {{num_teams values can only be specified with dims modifier}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_no_values() {
+  omp.target {
+    // expected-error @below {{num_teams dims modifier requires values to be specified}}
+    "omp.teams" () ({
+      omp.terminator
+    }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_num_teams_dims_type_mismatch() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i64
+    // expected-error @below {{num_teams dims modifier requires all values to have the same type}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {num_teams_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,2,0,0,0,0,0>} : (i32, i64) -> ()
     omp.terminator
   }
   return
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index bb154fad12742..3633a4be1eb62 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1108,8 +1108,8 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
-  // CHECK: omp.teams num_teams_multi_dim(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
-  omp.teams num_teams_multi_dim(dims(3): %lb, %ub, %ub : i32, i32, i32) {
+  // CHECK: omp.teams num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  omp.teams num_teams(dims(3): %lb, %ub, %ub : i32) {
     // CHECK: omp.terminator
     omp.terminator
   }

>From 51accd16fb27a9c7350c0d74602ae2c70097f496 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 10 Dec 2025 15:40:37 +0530
Subject: [PATCH 6/6] Mark mlir->llvmir for num_teams with dims as NYI

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp     | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0d5b553c8e652..de1edde8be188 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -1981,6 +1981,10 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
   if (failed(checkImplementationStatus(*op)))
     return failure();
 
+  if (op.getNumTeamsDims().has_value() || !op.getNumTeamsValues().empty()) {
+    return op.emitError("Lowering of num_teams with dims modifier is NYI.");
+  }
+
   DenseMap<Value, llvm::Value *> reductionVariableMap;
   unsigned numReductionVars = op.getNumReductionVars();
   SmallVector<omp::DeclareReductionOp> reductionDecls;
@@ -5587,6 +5591,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
     for (Operation *user : blockArg.getUsers()) {
       llvm::TypeSwitch<Operation *>(user)
           .Case([&](omp::TeamsOp teamsOp) {
+            // num_teams dims and values are not yet supported
+            assert(!teamsOp.getNumTeamsDims().has_value() &&
+                   teamsOp.getNumTeamsValues().empty() &&
+                   "Lowering of num_teams with dims modifier is NYI.");
             if (teamsOp.getNumTeamsLower() == blockArg)
               numTeamsLower = hostEvalVar;
             else if (teamsOp.getNumTeamsUpper() == blockArg)
@@ -5709,6 +5717,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     // host_eval, but instead evaluated prior to entry to the region. This
     // ensures values are mapped and available inside of the target region.
     if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
+      // num_teams dims and values are not yet supported
+      assert(!teamsOp.getNumTeamsDims().has_value() &&
+             teamsOp.getNumTeamsValues().empty() &&
+             "Lowering of num_teams with dims modifier is NYI.");
       numTeamsLower = teamsOp.getNumTeamsLower();
       numTeamsUpper = teamsOp.getNumTeamsUpper();
       threadLimit = teamsOp.getThreadLimit();



More information about the Mlir-commits mailing list