[llvm-branch-commits] [flang] [mlir] [OpenMP][MLIR] Add thread_limit with dims modifier support (PR #171825)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Sat Jan 17 01:03:27 PST 2026


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171825

>From 45e752ec35553d1ad3ae556ff65c3afd546e2e6f Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 28 Nov 2025 13:37:14 +0530
Subject: [PATCH 01/10] [OpenMP][MLIR] Add num_teams clause with dims modifier
 support

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 72 +++++++++++++++++++
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  3 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  5 ++
 3 files changed, 79 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index b612d4e136baf..ed24530464ea4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1567,4 +1567,76 @@ class OpenMP_UseDevicePtrClauseSkip<
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
+//===----------------------------------------------------------------------===//
+// V6.2: Multidimensional `num_teams` clause with dims modifier
+//===----------------------------------------------------------------------===//
+
+class OpenMP_NumTeamsMultiDimClauseSkip<
+    bit traits = false, bit arguments = false, bit assemblyFormat = false,
+    bit description = false, bit extraClassDeclaration = false
+  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
+            extraClassDeclaration> {
+  let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
+    Variadic<AnyInteger>:$num_teams_values
+  );
+
+  let optAssemblyFormat = [{
+    `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
+                                                    $num_teams_values,
+                                                    type($num_teams_values)) `)`
+  }];
+
+  let description = [{
+    The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
+    the number of teams to be created in a multidimensional team space.
+
+    The dims modifier for the num_teams_multi_dim clause specifies the number of
+    dimensions for the league space (team space) that the clause arranges.
+    The dimensions argument in the dims modifier specifies the number of
+    dimensions and determines the length of the list argument. The list items
+    are specified in ascending order according to the ordinal number of the
+    dimensions (dimension 0, 1, 2, ..., N-1).
+
+    - If `dims` is not specified: The space is unidimensional (1D) with a single value
+    - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
+    - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
+
+    **Examples:**
+    - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
+      3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
+    - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasDimsModifier() {
+      return getNumTeamsDims().has_value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    /// Returns 1 if dims modifier is not present (unidimensional by default)
+    unsigned getNumDimensions() {
+      if (!hasDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumTeamsDims());
+    }
+
+    /// Returns all dimension values as an operand range
+    ::mlir::OperandRange getDimensionValues() {
+      return getNumTeamsValues();
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than the number of dimension values
+    ::mlir::Value getDimensionValue(unsigned index) {
+    assert(index < getDimensionValues().size() &&
+        "Dimension index out of bounds");
+      return getDimensionValues()[index];
+    }
+  }];
+}
+
+def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
+
 #endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d4e8cecda2601..76eeb0bd70ec3 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,7 +241,8 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
     AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
-    OpenMP_PrivateClause, OpenMP_ReductionClause, OpenMP_ThreadLimitClause
+    OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
+    OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "teams construct";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 25bf4e70d9a83..7a9a45b160ba3 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2625,8 +2625,13 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+<<<<<<< HEAD
                  clauses.ifExpr, clauses.numTeamsVals, clauses.numTeamsLower,
                  clauses.numTeamsUpper,
+=======
+                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
+                 clauses.numTeamsDims, clauses.numTeamsValues,
+>>>>>>> [OpenMP][MLIR] Add num_teams clause with dims modifier support
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
                  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
                  clauses.reductionVars,

>From d5c8a0c4f80749882fea42bf86d1ed71f03a5c83 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 08:48:30 +0530
Subject: [PATCH 02/10] Update num_teams to have just the list and no dims(N)
 syntax

---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 2 +-
 mlir/test/Dialect/OpenMP/ops.mlir            | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 7a9a45b160ba3..d8817617f5c6c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4616,7 +4616,7 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
       p << " : " << upperBoundType;
     } else {
       // Upper only: to upper : type
-      p << " to ";
+      p << "to ";
       p.printOperand(upperBound);
       p << " : " << upperBoundType;
     }
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 49a88e0443e60..0624b31844fc4 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1103,7 +1103,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
-  // CHECK: omp.teams num_teams( to %{{.+}} : i32)
+  // CHECK: omp.teams num_teams(to %{{.+}} : i32)
   omp.teams num_teams(to %ub : i32) {
     // CHECK: omp.terminator
     omp.terminator
@@ -3084,7 +3084,7 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?
 
 func.func @omp_target_host_eval(%x : i32) {
   // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
-  // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32)
+  // CHECK: omp.teams num_teams(to %[[HOST_ARG]] : i32)
   // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32)
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.teams num_teams( to %arg0 : i32) thread_limit(%arg0 : i32) {

>From f83855f2eb6922042c2ed1f287096c9c778b13bf Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 13:35:05 +0530
Subject: [PATCH 03/10] [OpenMP][MLIR] Add thread_limit with dims modifier
 support

---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  |  16 +-
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  29 +++-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  69 ++++++++-
 mlir/test/Dialect/OpenMP/invalid.mlir         | 139 +++++++++++++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             |   8 +-
 5 files changed, 249 insertions(+), 12 deletions(-)

diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 7b61539984232..a3b9e5c76bdd2 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -766,6 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
       targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
       innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
       targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
       targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
@@ -1485,8 +1486,9 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
-      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
-      targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
   auto *preTargetBlock = rewriter.createBlock(
       &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
   IRMapping preMapping;
@@ -1575,8 +1577,9 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
-      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
-      targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
   auto *isolatedTargetBlock =
       rewriter.createBlock(&isolatedTargetOp.getRegion(),
                            isolatedTargetOp.getRegion().begin(), {}, {});
@@ -1655,8 +1658,9 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
-      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimit(),
-      targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateNeedsBarrierAttr(),
+      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
   // Create the block for postTargetOp
   auto *postTargetBlock = rewriter.createBlock(
       &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index ed24530464ea4..45a27722f968d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1452,16 +1452,43 @@ class OpenMP_ThreadLimitClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
+    Variadic<AnyInteger>:$thread_limit_dims_values,
     Optional<AnyInteger>:$thread_limit
   );
 
   let optAssemblyFormat = [{
-    `thread_limit` `(` $thread_limit `:` type($thread_limit) `)`
+    `thread_limit` `(` custom<ThreadLimitClause>(
+      $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values),
+      $thread_limit, type($thread_limit)
+    ) `)`
   }];
 
   let description = [{
     The optional `thread_limit` specifies the limit on the number of threads.
   }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasThreadLimitDimsModifier() {
+      return getThreadLimitNumDims().has_value() && getThreadLimitNumDims().value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getThreadLimitDimsCount() {
+      if (!hasThreadLimitDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getThreadLimitNumDims());
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getThreadLimitDimsCount()
+    ::mlir::Value getThreadLimitDimensionValue(unsigned index) {
+      assert(index < getThreadLimitDimsCount() &&
+             "Thread limit dims index out of bounds");
+      return getThreadLimitDimsValues()[index];
+    }
+  }];
 }
 
 def OpenMP_ThreadLimitClause : OpenMP_ThreadLimitClauseSkip<>;
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index d8817617f5c6c..a4a669eaa77c0 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2210,10 +2210,30 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
                   makeArrayAttr(ctx, clauses.privateSyms),
-                  clauses.privateNeedsBarrier, clauses.threadLimit,
+                  clauses.privateNeedsBarrier, clauses.threadLimitNumDims,
+                  clauses.threadLimitDimsValues, clauses.threadLimit,
                   /*private_maps=*/nullptr);
 }
 
+// Helper that verifies the thread_limit clause restrictions.
+static LogicalResult
+verifyThreadLimitClause(Operation *op,
+                        std::optional<IntegerAttr> threadLimitNumDims,
+                        OperandRange threadLimitDimsValues, Value threadLimit) {
+  bool hasDimsModifier =
+      threadLimitNumDims.has_value() && threadLimitNumDims.value();
+
+  if (hasDimsModifier && threadLimit) {
+    return op->emitError("thread_limit with dims modifier cannot be used "
+                         "together with number of threads");
+  }
+
+  if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues)))
+    return failure();
+
+  return success();
+}
+
 LogicalResult TargetOp::verify() {
   if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
     return failure();
@@ -2225,6 +2245,11 @@ LogicalResult TargetOp::verify() {
   if (failed(verifyMapClause(*this, getMapVars())))
     return failure();
 
+  if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
+                                     getThreadLimitDimsValues(),
+                                     getThreadLimit())))
+    return failure();
+
   return verifyPrivateVarsMapping(*this);
 }
 
@@ -2692,6 +2717,12 @@ LogicalResult TeamsOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // Check for thread_limit clause restrictions
+  if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
+                                     getThreadLimitDimsValues(),
+                                     getThreadLimit())))
+    return failure();
+
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4623,6 +4654,42 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for thread_limit clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                       SmallVectorImpl<Type> &types,
+                       std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                       Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
+                                   IntegerAttr dimsAttr, OperandRange values,
+                                   TypeRange types, Value bounds,
+                                   Type boundsType) {
+  if (!values.empty()) {
+    // Multidimensional: dims(N): values : type
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  } else if (bounds) {
+    // Scalar form (no dims modifier): value : type
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index d451b14e8bfc9..2b030f2a775c4 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
     // expected-error @below {{expected equal sizes for allocate and allocator variables}}
     "omp.teams" (%data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
     omp.terminator
   }
   return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
     // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
     "omp.teams" (%lb) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0,0>} : (i32) -> ()
     omp.terminator
   }
   return
@@ -1489,6 +1489,139 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
 
 // -----
 
+func.func @omp_teams_thread_limit_dims_mismatch() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    // expected-error @below {{dims(3) specified but 2 values provided}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_dims_with_scalar() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    %tl = arith.constant 4 : i32
+    // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+    "omp.teams" (%v0, %v1, %tl) ({
+      omp.terminator
+    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_dims_no_values() {
+  omp.target {
+    // expected-error @below {{dims modifier requires values to be specified}}
+    "omp.teams" () ({
+      omp.terminator
+    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_values_without_dims() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i32
+    // expected-error @below {{dims values can only be specified with dims modifier}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_teams_thread_limit_dims_type_mismatch() {
+  omp.target {
+    %v0 = arith.constant 1 : i32
+    %v1 = arith.constant 2 : i64
+    // expected-error @below {{dims modifier requires all values to have the same type}}
+    "omp.teams" (%v0, %v1) ({
+      omp.terminator
+    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+    omp.terminator
+  }
+  return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_mismatch() {
+  %v0 = arith.constant 1 : i32
+  %v1 = arith.constant 2 : i32
+  // expected-error @below {{dims(3) specified but 2 values provided}}
+  "omp.target" (%v0, %v1) ({
+    omp.terminator
+  }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+  return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_with_scalar() {
+  %v0 = arith.constant 1 : i32
+  %v1 = arith.constant 2 : i32
+  %tl = arith.constant 4 : i32
+  // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+  "omp.target" (%v0, %v1, %tl) ({
+    omp.terminator
+  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+  return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_no_values() {
+  // expected-error @below {{dims modifier requires values to be specified}}
+  "omp.target" () ({
+    omp.terminator
+  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
+  return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_values_without_dims() {
+  %v0 = arith.constant 1 : i32
+  %v1 = arith.constant 2 : i32
+  // expected-error @below {{dims values can only be specified with dims modifier}}
+  "omp.target" (%v0, %v1) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+  return
+}
+
+// -----
+
+func.func @omp_target_thread_limit_dims_type_mismatch() {
+  %v0 = arith.constant 1 : i32
+  %v1 = arith.constant 2 : i64
+  // expected-error @below {{dims modifier requires all values to have the same type}}
+  "omp.target" (%v0, %v1) ({
+    omp.terminator
+  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @omp_sections(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.sections" (%data_var) ({
@@ -2475,7 +2608,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 0624b31844fc4..de5f604e0706d 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -824,7 +824,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%device, %if_cond, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}
@@ -1136,6 +1136,12 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
+  // CHECK: omp.teams thread_limit(dims(2): %{{.*}}, %{{.*}} : i32)
+  omp.teams thread_limit(dims(2): %lb, %ub : i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
   // Test reduction.
   %c1 = arith.constant 1 : i32
   %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr

>From fdff6229223b9e714aee10e2b83448109aa8db55 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 18:35:55 +0530
Subject: [PATCH 04/10] update thread_limit description

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 45a27722f968d..66c3a47c0b5ec 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1465,7 +1465,18 @@ class OpenMP_ThreadLimitClauseSkip<
   }];
 
   let description = [{
-    The optional `thread_limit` specifies the limit on the number of threads.
+    The `thread_limit` clause specifies the limit on the number of threads.
+
+    With dims modifier:
+    - The number of dimensions is specified by the `thread_limit_num_dims` attribute.
+    - The values for each dimension are specified by the `thread_limit_dims_values` operands.
+    - Format: `thread_limit(dims(N): values : type)`
+    - Example: `thread_limit(dims(2): %n, %m : i64)`
+
+    Without dims modifier:
+    - The number of threads is specified by the `thread_limit` operand.
+    - Format: `thread_limit(number_of_threads : type)`
+    - Example: `thread_limit(%n : i64)`
   }];
 
   let extraClassDeclaration = [{

>From 487fe3734ade3c0bf1e5aaff717bff9719cbd3eb Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 17 Dec 2025 09:06:36 +0530
Subject: [PATCH 05/10] Remove separate thread_limit argument from clause

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  3 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 16 +++---
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  |  8 +--
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 16 +++---
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 54 +++++++------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 11 ++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 28 +++++-----
 mlir/test/Dialect/OpenMP/ops.mlir             |  2 +-
 8 files changed, 63 insertions(+), 75 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 2f531efaf09aa..9ee0a55243a2c 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -668,8 +668,9 @@ bool ClauseProcessor::processThreadLimit(
     lower::StatementContext &stmtCtx,
     mlir::omp::ThreadLimitClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
-    result.threadLimit =
+    mlir::Value threadLimitVal =
         fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+    result.threadLimitDimsValues.push_back(threadLimitVal);
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 0764693f748a5..259771c523e93 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -102,8 +102,9 @@ class HostEvalInfo {
     if (ops.numThreads)
       vars.push_back(ops.numThreads);
 
-    if (ops.threadLimit)
-      vars.push_back(ops.threadLimit);
+    // Old spec: single value in threadLimitDimsValues
+    for (mlir::Value val : ops.threadLimitDimsValues)
+      vars.push_back(val);
   }
 
   /// Update \c ops, replacing all values with the corresponding block argument
@@ -116,7 +117,7 @@ class HostEvalInfo {
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
                    (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
-                   (ops.threadLimit ? 1 : 0) &&
+                   ops.threadLimitDimsValues.size() &&
            "invalid block argument list");
     int argIndex = 0;
     for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
@@ -137,8 +138,8 @@ class HostEvalInfo {
     if (ops.numThreads)
       ops.numThreads = args[argIndex++];
 
-    if (ops.threadLimit)
-      ops.threadLimit = args[argIndex++];
+    for (size_t i = 0; i < ops.threadLimitDimsValues.size(); ++i)
+      ops.threadLimitDimsValues[i] = args[argIndex++];
   }
 
   /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
@@ -185,12 +186,13 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::TeamsOperands &clauseOps) {
-    if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+    if (!ops.numTeamsLower && !ops.numTeamsUpper &&
+        ops.threadLimitDimsValues.empty())
       return false;
 
     clauseOps.numTeamsLower = ops.numTeamsLower;
     clauseOps.numTeamsUpper = ops.numTeamsUpper;
-    clauseOps.threadLimit = ops.threadLimit;
+    clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues;
     return true;
   }
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index a3b9e5c76bdd2..4d3fec3b0710f 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -767,7 +767,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
       innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
       targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
       targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
-      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateMapsAttr());
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
   rewriter.replaceOp(targetOp, targetDataOp);
@@ -1488,7 +1488,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
       targetOp.getPrivateNeedsBarrierAttr(),
       targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
-      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateMapsAttr());
   auto *preTargetBlock = rewriter.createBlock(
       &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
   IRMapping preMapping;
@@ -1579,7 +1579,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
       targetOp.getPrivateNeedsBarrierAttr(),
       targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
-      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateMapsAttr());
   auto *isolatedTargetBlock =
       rewriter.createBlock(&isolatedTargetOp.getRegion(),
                            isolatedTargetOp.getRegion().begin(), {}, {});
@@ -1660,7 +1660,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
       targetOp.getPrivateNeedsBarrierAttr(),
       targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
-      targetOp.getThreadLimit(), targetOp.getPrivateMapsAttr());
+      targetOp.getPrivateMapsAttr());
   // Create the block for postTargetOp
   auto *postTargetBlock = rewriter.createBlock(
       &postTargetOp.getRegion(), postTargetOp.getRegion().begin(), {}, {});
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 66c3a47c0b5ec..d350ff71c19b5 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1453,14 +1453,12 @@ class OpenMP_ThreadLimitClauseSkip<
                     extraClassDeclaration> {
   let arguments = (ins
     ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
-    Variadic<AnyInteger>:$thread_limit_dims_values,
-    Optional<AnyInteger>:$thread_limit
+    Variadic<AnyInteger>:$thread_limit_dims_values
   );
 
   let optAssemblyFormat = [{
     `thread_limit` `(` custom<ThreadLimitClause>(
-      $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values),
-      $thread_limit, type($thread_limit)
+      $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values)
     ) `)`
   }];
 
@@ -1468,14 +1466,14 @@ class OpenMP_ThreadLimitClauseSkip<
     The `thread_limit` clause specifies the limit on the number of threads.
 
     With dims modifier:
-    - The number of dimensions is specified by the `thread_limit_num_dims` attribute.
-    - The values for each dimension are specified by the `thread_limit_dims_values` attribute.
+    - The number of dimensions is specified by the `thread_limit_num_dims`.
+    - The values for each dimension are specified by the `thread_limit_dims_values`.
     - Format: `thread_limit(dims(N): values : type)`
     - Example: `thread_limit(dims(2): %n, %m : i64)`
 
     Without dims modifier:
-    - The number of threads is specified by the `thread_limit`.
-    - Format: `thread_limit(number_of_threads : type)`
+    - The number of threads is specified by the single value in `thread_limit_dims_values`.
+    - Format: `thread_limit(value : type)`
     - Example: `thread_limit(%n : i64)`
   }];
 
@@ -1497,6 +1495,8 @@ class OpenMP_ThreadLimitClauseSkip<
     ::mlir::Value getThreadLimitDimensionValue(unsigned index) {
       assert(index < getThreadLimitDimsCount() &&
              "Thread limit dims index out of bounds");
+      if (getThreadLimitDimsValues().empty())
+        return nullptr;
       return getThreadLimitDimsValues()[index];
     }
   }];
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a4a669eaa77c0..7f3e0d740ffcc 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2211,7 +2211,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
                   makeArrayAttr(ctx, clauses.privateSyms),
                   clauses.privateNeedsBarrier, clauses.threadLimitNumDims,
-                  clauses.threadLimitDimsValues, clauses.threadLimit,
+                  clauses.threadLimitDimsValues,
                   /*private_maps=*/nullptr);
 }
 
@@ -2219,15 +2219,7 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
 static LogicalResult
 verifyThreadLimitClause(Operation *op,
                         std::optional<IntegerAttr> threadLimitNumDims,
-                        OperandRange threadLimitDimsValues, Value threadLimit) {
-  bool hasDimsModifier =
-      threadLimitNumDims.has_value() && threadLimitNumDims.value();
-
-  if (hasDimsModifier && threadLimit) {
-    return op->emitError("thread_limit with dims modifier cannot be used "
-                         "together with number of threads");
-  }
-
+                        OperandRange threadLimitDimsValues) {
   if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues)))
     return failure();
 
@@ -2246,8 +2238,7 @@ LogicalResult TargetOp::verify() {
     return failure();
 
   if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
-                                     getThreadLimitDimsValues(),
-                                     getThreadLimit())))
+                                     getThreadLimitDimsValues())))
     return failure();
 
   return verifyPrivateVarsMapping(*this);
@@ -2265,10 +2256,9 @@ LogicalResult TargetOp::verifyRegions() {
        cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
     for (Operation *user : hostEvalArg.getUsers()) {
       if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
-        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
-                                teamsOp.getNumTeamsUpper(),
-                                teamsOp.getThreadLimit()},
-                               hostEvalArg))
+        if (teamsOp.getNumTeamsLower() == hostEvalArg ||
+            teamsOp.getNumTeamsUpper() == hostEvalArg ||
+            llvm::is_contained(teamsOp.getThreadLimitDimsValues(), hostEvalArg))
           continue;
 
         return emitOpError() << "host_eval argument only legal as 'num_teams' "
@@ -2719,8 +2709,7 @@ LogicalResult TeamsOp::verify() {
 
   // Check for thread_limit clause restrictions
   if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
-                                     getThreadLimitDimsValues(),
-                                     getThreadLimit())))
+                                     getThreadLimitDimsValues())))
     return failure();
 
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
@@ -4660,34 +4649,29 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
 static ParseResult
 parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
                        SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                       SmallVectorImpl<Type> &types,
-                       std::optional<OpAsmParser::UnresolvedOperand> &bounds,
-                       Type &boundsType) {
+                       SmallVectorImpl<Type> &types) {
+  // Try parsing with dims modifier: dims(N): values : type
   if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
     return success();
   }
 
-  OpAsmParser::UnresolvedOperand boundsOperand;
-  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
-      parser.parseType(boundsType)) {
+  // Without dims modifier: value : type
+  OpAsmParser::UnresolvedOperand singleValue;
+  Type singleType;
+  if (parser.parseOperand(singleValue) || parser.parseColon() ||
+      parser.parseType(singleType)) {
     return failure();
   }
-  bounds = boundsOperand;
+  values.push_back(singleValue);
+  types.push_back(singleType);
   return success();
 }
 
 static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
                                    IntegerAttr dimsAttr, OperandRange values,
-                                   TypeRange types, Value bounds,
-                                   Type boundsType) {
-  if (!values.empty()) {
-    // Multidimensional: dims(N): values : type
-    printDimsModifierWithValues(p, dimsAttr, values, types);
-  } else if (bounds) {
-    // Both bounds: bounds : type
-    p.printOperand(bounds);
-    p << " : " << boundsType;
-  }
+                                   TypeRange types) {
+  // Multidimensional: dims(N): values : type
+  printDimsModifierWithValues(p, dimsAttr, values, types);
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 0b7bf64cefe4c..68166c3cf7570 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -2075,7 +2075,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
     numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
 
   llvm::Value *threadLimit = nullptr;
-  if (Value threadLimitVar = op.getThreadLimit())
+  if (Value threadLimitVar = op.getThreadLimitDimensionValue(0))
     threadLimit = moduleTranslation.lookupValue(threadLimitVar);
 
   llvm::Value *ifExpr = nullptr;
@@ -6044,7 +6044,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               numTeamsLower = hostEvalVar;
             else if (teamsOp.getNumTeamsUpper() == blockArg)
               numTeamsUpper = hostEvalVar;
-            else if (teamsOp.getThreadLimit() == blockArg)
+            else if (teamsOp.getThreadLimitDimensionValue(0) == blockArg)
               threadLimit = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6164,7 +6164,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
       numTeamsLower = teamsOp.getNumTeamsLower();
       numTeamsUpper = teamsOp.getNumTeamsUpper();
-      threadLimit = teamsOp.getThreadLimit();
+      threadLimit = teamsOp.getThreadLimitDimensionValue(0);
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
@@ -6209,7 +6209,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Extract 'thread_limit' clause from 'target' and 'teams' directives.
   int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
-  setMaxValueFromClause(targetOp.getThreadLimit(), targetThreadLimitVal);
+  setMaxValueFromClause(targetOp.getThreadLimitDimensionValue(0),
+                        targetThreadLimitVal);
   setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
 
   // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
@@ -6288,7 +6289,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                          teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
-  if (Value targetThreadLimit = targetOp.getThreadLimit())
+  if (Value targetThreadLimit = targetOp.getThreadLimitDimensionValue(0))
     attrs.TargetThreadLimit.front() =
         moduleTranslation.lookupValue(targetThreadLimit);
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 2b030f2a775c4..14ee2948d3634 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1438,7 +1438,7 @@ func.func @omp_teams_allocate(%data_var : memref<i32>) {
     // expected-error @below {{expected equal sizes for allocate and allocator variables}}
     "omp.teams" (%data_var) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
+    }) {operandSegmentSizes = array<i32: 1,0,0,0,0,0,0,0,0>} : (memref<i32>) -> ()
     omp.terminator
   }
   return
@@ -1451,7 +1451,7 @@ func.func @omp_teams_num_teams1(%lb : i32) {
     // expected-error @below {{expected num_teams upper bound to be defined if the lower bound is defined}}
     "omp.teams" (%lb) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0,0>} : (i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,1,0,0,0,0>} : (i32) -> ()
     omp.terminator
   }
   return
@@ -1496,7 +1496,7 @@ func.func @omp_teams_thread_limit_dims_mismatch() {
     // expected-error @below {{dims(3) specified but 2 values provided}}
     "omp.teams" (%v0, %v1) ({
       omp.terminator
-    }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+    }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
     omp.terminator
   }
   return
@@ -1509,10 +1509,10 @@ func.func @omp_teams_thread_limit_dims_with_scalar() {
     %v0 = arith.constant 1 : i32
     %v1 = arith.constant 2 : i32
     %tl = arith.constant 4 : i32
-    // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+    // expected-error @below {{dims(2) specified but 3 values provided}}
     "omp.teams" (%v0, %v1, %tl) ({
       omp.terminator
-    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
     omp.terminator
   }
   return
@@ -1540,7 +1540,7 @@ func.func @omp_teams_thread_limit_values_without_dims() {
     // expected-error @below {{dims values can only be specified with dims modifier}}
     "omp.teams" (%v0, %v1) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
     omp.terminator
   }
   return
@@ -1555,7 +1555,7 @@ func.func @omp_teams_thread_limit_dims_type_mismatch() {
     // expected-error @below {{dims modifier requires all values to have the same type}}
     "omp.teams" (%v0, %v1) ({
       omp.terminator
-    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
     omp.terminator
   }
   return
@@ -1569,7 +1569,7 @@ func.func @omp_target_thread_limit_dims_mismatch() {
   // expected-error @below {{dims(3) specified but 2 values provided}}
   "omp.target" (%v0, %v1) ({
     omp.terminator
-  }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+  }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
   return
 }
 
@@ -1579,10 +1579,10 @@ func.func @omp_target_thread_limit_dims_with_scalar() {
   %v0 = arith.constant 1 : i32
   %v1 = arith.constant 2 : i32
   %tl = arith.constant 4 : i32
-  // expected-error @below {{thread_limit with dims modifier cannot be used together with number of threads}}
+  // expected-error @below {{dims(2) specified but 3 values provided}}
   "omp.target" (%v0, %v1, %tl) ({
     omp.terminator
-  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,1>} : (i32, i32, i32) -> ()
+  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
   return
 }
 
@@ -1592,7 +1592,7 @@ func.func @omp_target_thread_limit_dims_no_values() {
   // expected-error @below {{dims modifier requires values to be specified}}
   "omp.target" () ({
     omp.terminator
-  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
+  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
   return
 }
 
@@ -1604,7 +1604,7 @@ func.func @omp_target_thread_limit_values_without_dims() {
   // expected-error @below {{dims values can only be specified with dims modifier}}
   "omp.target" (%v0, %v1) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
   return
 }
 
@@ -1616,7 +1616,7 @@ func.func @omp_target_thread_limit_dims_type_mismatch() {
   // expected-error @below {{dims modifier requires all values to have the same type}}
   "omp.target" (%v0, %v1) ({
     omp.terminator
-  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2,0>} : (i32, i64) -> ()
+  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
   return
 }
 
@@ -2608,7 +2608,7 @@ func.func @omp_target_depend(%data_var: memref<i32>) {
   // expected-error @below {{op expected as many depend values as depend variables}}
     "omp.target"(%data_var) ({
       "omp.terminator"() : () -> ()
-    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
+    }) {depend_kinds = [], operandSegmentSizes = array<i32: 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0>} : (memref<i32>) -> ()
    "func.return"() : () -> ()
 }
 
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index de5f604e0706d..0255b1eb6f10f 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -824,7 +824,7 @@ func.func @omp_target(%if_cond : i1, %device : si32,  %num_threads : i32, %devic
     "omp.target"(%device, %if_cond, %num_threads) ({
        // CHECK: omp.terminator
        omp.terminator
-    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
+    }) {nowait, operandSegmentSizes = array<i32: 0,0,0,1,0,0,1,0,0,0,0,1>} : ( si32, i1, i32 ) -> ()
 
     // Test with optional map clause.
     // CHECK: %[[MAP_A:.*]] = omp.map.info var_ptr(%[[VAL_1:.*]] : memref<?xi32>, tensor<?xi32>)   map_clauses(always, to) capture(ByRef) -> memref<?xi32> {name = ""}

>From 014e0c3e0345cc9e698e4b10c2a73d24d2879e3d Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 10:11:01 +0530
Subject: [PATCH 06/10] comments fixes

---
 mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d350ff71c19b5..3c394dc67cbe1 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1453,7 +1453,7 @@ class OpenMP_ThreadLimitClauseSkip<
                     extraClassDeclaration> {
   let arguments = (ins
     ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
-    Variadic<AnyInteger>:$thread_limit_dims_values
+    Variadic<IntLikeType>:$thread_limit_dims_values
   );
 
   let optAssemblyFormat = [{

>From ebb7571444cb1cd4ed9643336d74fd50e0b726d7 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 15:07:17 +0530
Subject: [PATCH 07/10] fix comment

---
 .../Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp   | 1 +
 1 file changed, 1 insertion(+)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 68166c3cf7570..d491925505abb 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -6040,6 +6040,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
     for (Operation *user : blockArg.getUsers()) {
       llvm::TypeSwitch<Operation *>(user)
           .Case([&](omp::TeamsOp teamsOp) {
+            // num_teams dims and values are not yet supported
             if (teamsOp.getNumTeamsLower() == blockArg)
               numTeamsLower = hostEvalVar;
             else if (teamsOp.getNumTeamsUpper() == blockArg)

>From 4c401b11cf15a820cc7af6cb035be0a06bf61ba2 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 Jan 2026 12:19:24 +0530
Subject: [PATCH 08/10] [Flang] Add missing threadLimitNumDims in TeamsOperands
 apply method

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 259771c523e93..26c1e0166f393 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -193,6 +193,7 @@ class HostEvalInfo {
     clauseOps.numTeamsLower = ops.numTeamsLower;
     clauseOps.numTeamsUpper = ops.numTeamsUpper;
     clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues;
+    clauseOps.threadLimitNumDims = ops.threadLimitNumDims;
     return true;
   }
 

>From 7cea2d521763d70509094d237908d4e460fc56d5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 09:55:34 +0530
Subject: [PATCH 09/10] remove dims(N) syntax and just use list for dims vals

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |   2 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             |  15 +-
 .../Optimizer/OpenMP/LowerWorkdistribute.cpp  |  12 +-
 flang/test/Lower/OpenMP/teams.f90             |   2 +-
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      |  48 +++----
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  |  56 ++------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  26 ++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 133 ------------------
 mlir/test/Dialect/OpenMP/ops.mlir             |  15 +-
 mlir/test/Target/LLVMIR/openmp-todo.mlir      |  11 ++
 10 files changed, 88 insertions(+), 232 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 9ee0a55243a2c..d487f5d686b67 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -670,7 +670,7 @@ bool ClauseProcessor::processThreadLimit(
   if (auto *clause = findUniqueClause<omp::clause::ThreadLimit>()) {
     mlir::Value threadLimitVal =
         fir::getBase(converter.genExprValue(clause->v, stmtCtx));
-    result.threadLimitDimsValues.push_back(threadLimitVal);
+    result.threadLimitVals.push_back(threadLimitVal);
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 26c1e0166f393..670143fd9b1c4 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -102,8 +102,7 @@ class HostEvalInfo {
     if (ops.numThreads)
       vars.push_back(ops.numThreads);
 
-    // Old spec: single value in threadLimitDimsValues
-    for (mlir::Value val : ops.threadLimitDimsValues)
+    for (mlir::Value val : ops.threadLimitVals)
       vars.push_back(val);
   }
 
@@ -117,7 +116,7 @@ class HostEvalInfo {
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
                    (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
-                   ops.threadLimitDimsValues.size() &&
+                   ops.threadLimitVals.size() &&
            "invalid block argument list");
     int argIndex = 0;
     for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
@@ -138,8 +137,8 @@ class HostEvalInfo {
     if (ops.numThreads)
       ops.numThreads = args[argIndex++];
 
-    for (size_t i = 0; i < ops.threadLimitDimsValues.size(); ++i)
-      ops.threadLimitDimsValues[i] = args[argIndex++];
+    for (size_t i = 0; i < ops.threadLimitVals.size(); ++i)
+      ops.threadLimitVals[i] = args[argIndex++];
   }
 
   /// Update \p clauseOps and \p ivOut with the corresponding host-evaluated
@@ -186,14 +185,12 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::TeamsOperands &clauseOps) {
-    if (!ops.numTeamsLower && !ops.numTeamsUpper &&
-        ops.threadLimitDimsValues.empty())
+    if (!ops.numTeamsLower && !ops.numTeamsUpper && ops.threadLimitVals.empty())
       return false;
 
     clauseOps.numTeamsLower = ops.numTeamsLower;
     clauseOps.numTeamsUpper = ops.numTeamsUpper;
-    clauseOps.threadLimitDimsValues = ops.threadLimitDimsValues;
-    clauseOps.threadLimitNumDims = ops.threadLimitNumDims;
+    clauseOps.threadLimitVals = ops.threadLimitVals;
     return true;
   }
 
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
index 4d3fec3b0710f..b804a14e32f0c 100644
--- a/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkdistribute.cpp
@@ -766,8 +766,7 @@ FailureOr<omp::TargetOp> splitTargetData(omp::TargetOp targetOp,
       targetOp.getInReductionSymsAttr(), targetOp.getIsDevicePtrVars(),
       innerMapInfos, targetOp.getNowaitAttr(), targetOp.getPrivateVars(),
       targetOp.getPrivateSymsAttr(), targetOp.getPrivateNeedsBarrierAttr(),
-      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
-      targetOp.getPrivateMapsAttr());
+      targetOp.getThreadLimitVals(), targetOp.getPrivateMapsAttr());
   rewriter.inlineRegionBefore(targetOp.getRegion(), newTargetOp.getRegion(),
                               newTargetOp.getRegion().begin());
   rewriter.replaceOp(targetOp, targetDataOp);
@@ -1486,8 +1485,7 @@ genPreTargetOp(omp::TargetOp targetOp, SmallVector<Value> &preMapOperands,
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), preMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
-      targetOp.getPrivateNeedsBarrierAttr(),
-      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(),
       targetOp.getPrivateMapsAttr());
   auto *preTargetBlock = rewriter.createBlock(
       &preTargetOp.getRegion(), preTargetOp.getRegion().begin(), {}, {});
@@ -1577,8 +1575,7 @@ genIsolatedTargetOp(omp::TargetOp targetOp, SmallVector<Value> &postMapOperands,
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
-      targetOp.getPrivateNeedsBarrierAttr(),
-      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(),
       targetOp.getPrivateMapsAttr());
   auto *isolatedTargetBlock =
       rewriter.createBlock(&isolatedTargetOp.getRegion(),
@@ -1658,8 +1655,7 @@ static omp::TargetOp genPostTargetOp(omp::TargetOp targetOp,
       targetOp.getInReductionByrefAttr(), targetOp.getInReductionSymsAttr(),
       targetOp.getIsDevicePtrVars(), postMapOperands, targetOp.getNowaitAttr(),
       targetOp.getPrivateVars(), targetOp.getPrivateSymsAttr(),
-      targetOp.getPrivateNeedsBarrierAttr(),
-      targetOp.getThreadLimitNumDimsAttr(), targetOp.getThreadLimitDimsValues(),
+      targetOp.getPrivateNeedsBarrierAttr(), targetOp.getThreadLimitVals(),
       targetOp.getPrivateMapsAttr());
   // Create the block for postTargetOp
   auto *postTargetBlock = rewriter.createBlock(
diff --git a/flang/test/Lower/OpenMP/teams.f90 b/flang/test/Lower/OpenMP/teams.f90
index 47d379d6c2842..e5ba7070cf664 100644
--- a/flang/test/Lower/OpenMP/teams.f90
+++ b/flang/test/Lower/OpenMP/teams.f90
@@ -21,7 +21,7 @@ subroutine teams_numteams(num_teams)
   integer, intent(inout) :: num_teams
 
   ! CHECK: omp.teams
-  ! CHECK-SAME: num_teams( to %{{.*}}: i32)
+  ! CHECK-SAME: num_teams(to %{{.*}}: i32)
   !$omp teams num_teams(4)
   ! CHECK: fir.call
   call f1()
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 3c394dc67cbe1..0a5fd0e90366b 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1452,52 +1452,48 @@ class OpenMP_ThreadLimitClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$thread_limit_num_dims,
-    Variadic<IntLikeType>:$thread_limit_dims_values
+    Variadic<IntLikeType>:$thread_limit_vals
   );
 
   let optAssemblyFormat = [{
     `thread_limit` `(` custom<ThreadLimitClause>(
-      $thread_limit_num_dims, $thread_limit_dims_values, type($thread_limit_dims_values)
+      $thread_limit_vals, type($thread_limit_vals)
     ) `)`
   }];
 
   let description = [{
     The `thread_limit` clause specifies the limit on the number of threads.
 
-    With dims modifier:
-    - The number of dimensions is specified by the `thread_limit_num_dims`.
-    - The values for each dimension are specified by the `thread_limit_dims_values`.
-    - Format: `thread_limit(dims(N): values : type)`
-    - Example: `thread_limit(dims(2): %n, %m : i64)`
+    Multi-dimensional format (dims modifier):
+    - Multiple values can be specified for multi-dimensional thread limits.
+    - The number of dimensions is derived from the number of values.
+    - Values can have different integer types.
+    - Format: `thread_limit(%v1, %v2, ... : type1, type2, ...)`
+    - Example: `thread_limit(%n, %m : i32, i64)`
 
-    Without dims modifier:
-    - The number of threads is specified by the single value in `thread_limit_dims_values`.
-    - Format: `thread_limit(value : type)`
-    - Example: `thread_limit(%n : i64)`
+    Single value format:
+    - A single value specifies the thread limit.
+    - Format: `thread_limit(%value : type)`
+    - Example: `thread_limit(%n : i32)`
   }];
 
   let extraClassDeclaration = [{
-    /// Returns true if the dims modifier is explicitly present
-    bool hasThreadLimitDimsModifier() {
-      return getThreadLimitNumDims().has_value() && getThreadLimitNumDims().value();
+    /// Returns true if using multi-dimensional values (more than one value)
+    bool hasThreadLimitMultiDim() {
+      return getThreadLimitVals().size() > 1;
     }
 
-    /// Returns the number of dimensions specified by dims modifier
+    /// Returns the number of dimensions specified for thread_limit
     unsigned getThreadLimitDimsCount() {
-      if (!hasThreadLimitDimsModifier())
-        return 1;
-      return static_cast<unsigned>(*getThreadLimitNumDims());
+      return getThreadLimitVals().size();
     }
 
     /// Returns the value for a specific dimension index
-    /// Index must be less than getThreadLimitDimsCount()
-    ::mlir::Value getThreadLimitDimensionValue(unsigned index) {
-      assert(index < getThreadLimitDimsCount() &&
-             "Thread limit dims index out of bounds");
-      if (getThreadLimitDimsValues().empty())
-        return nullptr;
-      return getThreadLimitDimsValues()[index];
+    /// Index must be less than getThreadLimitVals().size()
+    ::mlir::Value getThreadLimitVal(unsigned index) {
+      assert(index < getThreadLimitVals().size() &&
+             "Thread limit index out of bounds");
+      return getThreadLimitVals()[index];
     }
   }];
 }
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 7f3e0d740ffcc..a5d36a13129b7 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2210,22 +2210,10 @@ void TargetOp::build(OpBuilder &builder, OperationState &state,
                   /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
                   clauses.mapVars, clauses.nowait, clauses.privateVars,
                   makeArrayAttr(ctx, clauses.privateSyms),
-                  clauses.privateNeedsBarrier, clauses.threadLimitNumDims,
-                  clauses.threadLimitDimsValues,
+                  clauses.privateNeedsBarrier, clauses.threadLimitVals,
                   /*private_maps=*/nullptr);
 }
 
-// helper for thread_limit clause restrictions
-static LogicalResult
-verifyThreadLimitClause(Operation *op,
-                        std::optional<IntegerAttr> threadLimitNumDims,
-                        OperandRange threadLimitDimsValues) {
-  if (failed(verifyDimsModifier(op, threadLimitNumDims, threadLimitDimsValues)))
-    return failure();
-
-  return success();
-}
-
 LogicalResult TargetOp::verify() {
   if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
     return failure();
@@ -2237,10 +2225,6 @@ LogicalResult TargetOp::verify() {
   if (failed(verifyMapClause(*this, getMapVars())))
     return failure();
 
-  if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
-                                     getThreadLimitDimsValues())))
-    return failure();
-
   return verifyPrivateVarsMapping(*this);
 }
 
@@ -2258,7 +2242,7 @@ LogicalResult TargetOp::verifyRegions() {
       if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
         if (teamsOp.getNumTeamsLower() == hostEvalArg ||
             teamsOp.getNumTeamsUpper() == hostEvalArg ||
-            llvm::is_contained(teamsOp.getThreadLimitDimsValues(), hostEvalArg))
+            llvm::is_contained(teamsOp.getThreadLimitVals(), hostEvalArg))
           continue;
 
         return emitOpError() << "host_eval argument only legal as 'num_teams' "
@@ -2652,7 +2636,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
                  clauses.reductionVars,
                  makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
                  makeArrayAttr(ctx, clauses.reductionSyms),
-                 clauses.threadLimit);
+                 clauses.threadLimitVals);
 }
 
 // Verify num_teams clause
@@ -2707,11 +2691,6 @@ LogicalResult TeamsOp::verify() {
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
-  // Check for thread_limit clause restrictions
-  if (failed(verifyThreadLimitClause(*this, getThreadLimitNumDimsAttr(),
-                                     getThreadLimitDimsValues())))
-    return failure();
-
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4636,7 +4615,7 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
       p << " : " << upperBoundType;
     } else {
       // Upper only: to upper : type
-      p << "to ";
+      p << " to ";
       p.printOperand(upperBound);
       p << " : " << upperBoundType;
     }
@@ -4647,31 +4626,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
 // Parser and printer for thread_limit clause
 //===----------------------------------------------------------------------===//
 static ParseResult
-parseThreadLimitClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+parseThreadLimitClause(OpAsmParser &parser,
                        SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
                        SmallVectorImpl<Type> &types) {
-  // Try parsing with dims modifier: dims(N): values : type
-  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
-    return success();
-  }
-
-  // Without dims modifier: value : type
-  OpAsmParser::UnresolvedOperand singleValue;
-  Type singleType;
-  if (parser.parseOperand(singleValue) || parser.parseColon() ||
-      parser.parseType(singleType)) {
+  // Parse comma-separated list of values with their types
+  // Format: %v1, %v2, ... : type1, type2, ...
+  if (parser.parseOperandList(values) || parser.parseColon() ||
+      parser.parseTypeList(types)) {
     return failure();
   }
-  values.push_back(singleValue);
-  types.push_back(singleType);
   return success();
 }
 
 static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
-                                   IntegerAttr dimsAttr, OperandRange values,
-                                   TypeRange types) {
-  // Multidimensional: dims(N): values : type
-  printDimsModifierWithValues(p, dimsAttr, values, types);
+                                   OperandRange values, TypeRange types) {
+  // Print values with their types
+  llvm::interleaveComma(values, p, [&](Value v) { p << v; });
+  p << " : ";
+  llvm::interleaveComma(types, p, [&](Type t) { p << t; });
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index d491925505abb..725d2d4345b3d 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
   };
+  auto checkThreadLimitMultiDim = [&todo](auto op, LogicalResult &result) {
+    if (op.hasThreadLimitMultiDim())
+      result = todo("thread_limit with multi-dimensional values");
+  };
 
   LogicalResult result = success();
   llvm::TypeSwitch<Operation &>(op)
@@ -404,7 +408,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::TeamsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
-        checkNumTeams(op, result);
+        checkNumTeamsMultiDim(op, result);
+        checkThreadLimitMultiDim(op, result);
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
@@ -442,6 +447,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkAllocate(op, result);
         checkBare(op, result);
         checkInReduction(op, result);
+        checkThreadLimitMultiDim(op, result);
       })
       .Default([](Operation &) {
         // Assume all clauses for an operation can be translated unless they are
@@ -2075,8 +2081,8 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
     numTeamsUpper = moduleTranslation.lookupValue(numTeamsUpperVar);
 
   llvm::Value *threadLimit = nullptr;
-  if (Value threadLimitVar = op.getThreadLimitDimensionValue(0))
-    threadLimit = moduleTranslation.lookupValue(threadLimitVar);
+  if (!op.getThreadLimitVals().empty())
+    threadLimit = moduleTranslation.lookupValue(op.getThreadLimitVal(0));
 
   llvm::Value *ifExpr = nullptr;
   if (Value ifVar = op.getIfExpr())
@@ -6045,7 +6051,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               numTeamsLower = hostEvalVar;
             else if (teamsOp.getNumTeamsUpper() == blockArg)
               numTeamsUpper = hostEvalVar;
-            else if (teamsOp.getThreadLimitDimensionValue(0) == blockArg)
+            else if (!teamsOp.getThreadLimitVals().empty() &&
+                     teamsOp.getThreadLimitVal(0) == blockArg)
               threadLimit = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6165,7 +6172,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     if (auto teamsOp = castOrGetParentOfType<omp::TeamsOp>(capturedOp)) {
       numTeamsLower = teamsOp.getNumTeamsLower();
       numTeamsUpper = teamsOp.getNumTeamsUpper();
-      threadLimit = teamsOp.getThreadLimitDimensionValue(0);
+      if (!teamsOp.getThreadLimitVals().empty())
+        threadLimit = teamsOp.getThreadLimitVal(0);
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
@@ -6210,8 +6218,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
   // Extract 'thread_limit' clause from 'target' and 'teams' directives.
   int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
-  setMaxValueFromClause(targetOp.getThreadLimitDimensionValue(0),
-                        targetThreadLimitVal);
+  if (!targetOp.getThreadLimitVals().empty())
+    setMaxValueFromClause(targetOp.getThreadLimitVal(0), targetThreadLimitVal);
   setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
 
   // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
@@ -6290,9 +6298,11 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
                          teamsThreadLimit, &lowerBounds, &upperBounds, &steps);
 
   // TODO: Handle constant 'if' clauses.
-  if (Value targetThreadLimit = targetOp.getThreadLimitDimensionValue(0))
+  if (!targetOp.getThreadLimitVals().empty()) {
+    Value targetThreadLimit = targetOp.getThreadLimitVal(0);
     attrs.TargetThreadLimit.front() =
         moduleTranslation.lookupValue(targetThreadLimit);
+  }
 
   if (numTeamsLower)
     attrs.MinTeams = moduleTranslation.lookupValue(numTeamsLower);
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 14ee2948d3634..d451b14e8bfc9 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -1489,139 +1489,6 @@ func.func @omp_teams_num_teams2(%lb : i32, %ub : i16) {
 
 // -----
 
-func.func @omp_teams_thread_limit_dims_mismatch() {
-  omp.target {
-    %v0 = arith.constant 1 : i32
-    %v1 = arith.constant 2 : i32
-    // expected-error @below {{dims(3) specified but 2 values provided}}
-    "omp.teams" (%v0, %v1) ({
-      omp.terminator
-    }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_dims_with_scalar() {
-  omp.target {
-    %v0 = arith.constant 1 : i32
-    %v1 = arith.constant 2 : i32
-    %tl = arith.constant 4 : i32
-    // expected-error @below {{dims(2) specified but 3 values provided}}
-    "omp.teams" (%v0, %v1, %tl) ({
-      omp.terminator
-    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_dims_no_values() {
-  omp.target {
-    // expected-error @below {{dims modifier requires values to be specified}}
-    "omp.teams" () ({
-      omp.terminator
-    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0>} : () -> ()
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_values_without_dims() {
-  omp.target {
-    %v0 = arith.constant 1 : i32
-    %v1 = arith.constant 2 : i32
-    // expected-error @below {{dims values can only be specified with dims modifier}}
-    "omp.teams" (%v0, %v1) ({
-      omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_teams_thread_limit_dims_type_mismatch() {
-  omp.target {
-    %v0 = arith.constant 1 : i32
-    %v1 = arith.constant 2 : i64
-    // expected-error @below {{dims modifier requires all values to have the same type}}
-    "omp.teams" (%v0, %v1) ({
-      omp.terminator
-    }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
-    omp.terminator
-  }
-  return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_mismatch() {
-  %v0 = arith.constant 1 : i32
-  %v1 = arith.constant 2 : i32
-  // expected-error @below {{dims(3) specified but 2 values provided}}
-  "omp.target" (%v0, %v1) ({
-    omp.terminator
-  }) {thread_limit_num_dims = 3 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
-  return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_with_scalar() {
-  %v0 = arith.constant 1 : i32
-  %v1 = arith.constant 2 : i32
-  %tl = arith.constant 4 : i32
-  // expected-error @below {{dims(2) specified but 3 values provided}}
-  "omp.target" (%v0, %v1, %tl) ({
-    omp.terminator
-  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,3>} : (i32, i32, i32) -> ()
-  return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_no_values() {
-  // expected-error @below {{dims modifier requires values to be specified}}
-  "omp.target" () ({
-    omp.terminator
-  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,0>} : () -> ()
-  return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_values_without_dims() {
-  %v0 = arith.constant 1 : i32
-  %v1 = arith.constant 2 : i32
-  // expected-error @below {{dims values can only be specified with dims modifier}}
-  "omp.target" (%v0, %v1) ({
-    omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i32) -> ()
-  return
-}
-
-// -----
-
-func.func @omp_target_thread_limit_dims_type_mismatch() {
-  %v0 = arith.constant 1 : i32
-  %v1 = arith.constant 2 : i64
-  // expected-error @below {{dims modifier requires all values to have the same type}}
-  "omp.target" (%v0, %v1) ({
-    omp.terminator
-  }) {thread_limit_num_dims = 2 : i64, operandSegmentSizes = array<i32: 0,0,0,0,0,0,0,0,0,0,0,2>} : (i32, i64) -> ()
-  return
-}
-
-// -----
-
 func.func @omp_sections(%data_var : memref<i32>) -> () {
   // expected-error @below {{expected equal sizes for allocate and allocator variables}}
   "omp.sections" (%data_var) ({
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 0255b1eb6f10f..965fa5a40cac0 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -1103,7 +1103,7 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
-  // CHECK: omp.teams num_teams(to %{{.+}} : i32)
+  // CHECK: omp.teams num_teams( to %{{.+}} : i32)
   omp.teams num_teams(to %ub : i32) {
     // CHECK: omp.terminator
     omp.terminator
@@ -1136,8 +1136,15 @@ func.func @omp_teams(%lb : i32, %ub : i32, %if_cond : i1, %num_threads : i32,
     omp.terminator
   }
 
-  // CHECK: omp.teams thread_limit(dims(2): %{{.*}}, %{{.*}} : i32)
-  omp.teams thread_limit(dims(2): %lb, %ub : i32) {
+  // CHECK: omp.teams thread_limit(%{{.*}}, %{{.*}} : i32, i32)
+  omp.teams thread_limit(%lb, %ub : i32, i32) {
+    // CHECK: omp.terminator
+    omp.terminator
+  }
+
+  // Test thread_limit with mixed types.
+  // CHECK: omp.teams thread_limit(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+  omp.teams thread_limit(%lb, %ub64, %ub16 : i32, i64, i16) {
     // CHECK: omp.terminator
     omp.terminator
   }
@@ -3090,7 +3097,7 @@ func.func @omp_target_private_with_map_idx(%map1: memref<?xi32>, %map2: memref<?
 
 func.func @omp_target_host_eval(%x : i32) {
   // CHECK: omp.target host_eval(%{{.*}} -> %[[HOST_ARG:.*]] : i32) {
-  // CHECK: omp.teams num_teams(to %[[HOST_ARG]] : i32)
+  // CHECK: omp.teams num_teams( to %[[HOST_ARG]] : i32)
   // CHECK-SAME: thread_limit(%[[HOST_ARG]] : i32)
   omp.target host_eval(%x -> %arg0 : i32) {
     omp.teams num_teams( to %arg0 : i32) thread_limit(%arg0 : i32) {
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 1ea56fdd0bf16..6e85b67796312 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -443,6 +443,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @teams_thread_limit_multi_dim(%lb : i32, %ub : i32) {
+  // expected-error@below {{not yet implemented: Unhandled clause thread_limit with multi-dimensional values in omp.teams operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.teams}}
+  omp.teams thread_limit(%lb, %ub : i32, i32) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
   // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}

>From bcc8484bd734697c5158c36368d26455ccb0231e Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sat, 17 Jan 2026 10:21:49 +0530
Subject: [PATCH 10/10] remove custom parser/printer for dims

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 86 ++-----------------
 mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td |  3 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 29 -------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 18 ++--
 4 files changed, 16 insertions(+), 120 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 0a5fd0e90366b..c70ee999a8153 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1456,9 +1456,7 @@ class OpenMP_ThreadLimitClauseSkip<
   );
 
   let optAssemblyFormat = [{
-    `thread_limit` `(` custom<ThreadLimitClause>(
-      $thread_limit_vals, type($thread_limit_vals)
-    ) `)`
+    `thread_limit` `(` $thread_limit_vals `:` type($thread_limit_vals) `)`
   }];
 
   let description = [{
@@ -1488,12 +1486,12 @@ class OpenMP_ThreadLimitClauseSkip<
       return getThreadLimitVals().size();
     }
 
-    /// Returns the value for a specific dimension index
-    /// Index must be less than getThreadLimitVals().size()
-    ::mlir::Value getThreadLimitVal(unsigned index) {
-      assert(index < getThreadLimitVals().size() &&
+    /// Returns the value for a specific dimension
+    /// dim must be less than getThreadLimitDimsCount()
+    ::mlir::Value getThreadLimit(unsigned dim = 0) {
+      assert(dim < getThreadLimitDimsCount() &&
              "Thread limit index out of bounds");
-      return getThreadLimitVals()[index];
+      return getThreadLimitVals()[dim];
     }
   }];
 }
@@ -1601,76 +1599,4 @@ class OpenMP_UseDevicePtrClauseSkip<
 
 def OpenMP_UseDevicePtrClause : OpenMP_UseDevicePtrClauseSkip<>;
 
-//===----------------------------------------------------------------------===//
-// V6.2: Multidimensional `num_teams` clause with dims modifier
-//===----------------------------------------------------------------------===//
-
-class OpenMP_NumTeamsMultiDimClauseSkip<
-    bit traits = false, bit arguments = false, bit assemblyFormat = false,
-    bit description = false, bit extraClassDeclaration = false
-  > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
-            extraClassDeclaration> {
-  let arguments = (ins
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_teams_dims,
-    Variadic<AnyInteger>:$num_teams_values
-  );
-
-  let optAssemblyFormat = [{
-    `num_teams_multi_dim` `(` custom<NumTeamsMultiDimClause>($num_teams_dims,
-                                                    $num_teams_values,
-                                                    type($num_teams_values)) `)`
-  }];
-
-  let description = [{
-    The `num_teams_multi_dim` clause with dims modifier support specifies the limit on
-    the number of teams to be created in a multidimensional team space.
-
-    The dims modifier for the num_teams_multi_dim clause specifies the number of
-    dimensions for the league space (team space) that the clause arranges.
-    The dimensions argument in the dims modifier specifies the number of
-    dimensions and determines the length of the list argument. The list items
-    are specified in ascending order according to the ordinal number of the
-    dimensions (dimension 0, 1, 2, ..., N-1).
-
-    - If `dims` is not specified: The space is unidimensional (1D) with a single value
-    - If `dims(1)` is specified: The space is explicitly unidimensional (1D)
-    - If `dims(N)` where N > 1: The space is strictly multidimensional (N-D)
-
-    **Examples:**
-    - `num_teams_multi_dim(dims(3): %nt0, %nt1, %nt2 : i32, i32, i32)` creates a
-      3-dimensional team space with limits nt0, nt1, nt2 for dimensions 0, 1, 2.
-    - `num_teams_multi_dim(%nt : i32)` creates a unidimensional team space with limit nt.
-  }];
-
-  let extraClassDeclaration = [{
-    /// Returns true if the dims modifier is explicitly present
-    bool hasDimsModifier() {
-      return getNumTeamsDims().has_value();
-    }
-
-    /// Returns the number of dimensions specified by dims modifier
-    /// Returns 1 if dims modifier is not present (unidimensional by default)
-    unsigned getNumDimensions() {
-      if (!hasDimsModifier())
-        return 1;
-      return static_cast<unsigned>(*getNumTeamsDims());
-    }
-
-    /// Returns all dimension values as an operand range
-    ::mlir::OperandRange getDimensionValues() {
-      return getNumTeamsValues();
-    }
-
-    /// Returns the value for a specific dimension index
-    /// Index must be less than getNumDimensions()
-    ::mlir::Value getDimensionValue(unsigned index) {
-    assert(index < getDimensionValues().size() &&
-        "Dimension index out of bounds");
-      return getDimensionValues()[index];
-    }
-  }];
-}
-
-def OpenMP_NumTeamsMultiDimClause : OpenMP_NumTeamsMultiDimClauseSkip<>;
-
 #endif // OPENMP_CLAUSES
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index 76eeb0bd70ec3..f8cb300d028d0 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -241,8 +241,7 @@ def TeamsOp : OpenMP_Op<"teams", traits = [
     AttrSizedOperandSegments, RecursiveMemoryEffects, OutlineableOpenMPOpInterface
   ], clauses = [
     OpenMP_AllocateClause, OpenMP_IfClause, OpenMP_NumTeamsClause,
-    OpenMP_NumTeamsMultiDimClause, OpenMP_PrivateClause, OpenMP_ReductionClause,
-    OpenMP_ThreadLimitClause
+    OpenMP_PrivateClause, OpenMP_ReductionClause,OpenMP_ThreadLimitClause
   ], singleRegion = true> {
   let summary = "teams construct";
   let description = [{
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index a5d36a13129b7..3238c7d4145e6 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2624,13 +2624,8 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
   TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-<<<<<<< HEAD
                  clauses.ifExpr, clauses.numTeamsVals, clauses.numTeamsLower,
                  clauses.numTeamsUpper,
-=======
-                 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
-                 clauses.numTeamsDims, clauses.numTeamsValues,
->>>>>>> [OpenMP][MLIR] Add num_teams clause with dims modifier support
                  /*private_vars=*/{}, /*private_syms=*/nullptr,
                  /*private_needs_barrier=*/nullptr, clauses.reductionMod,
                  clauses.reductionVars,
@@ -4622,30 +4617,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
-//===----------------------------------------------------------------------===//
-// Parser and printer for thread_limit clause
-//===----------------------------------------------------------------------===//
-static ParseResult
-parseThreadLimitClause(OpAsmParser &parser,
-                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                       SmallVectorImpl<Type> &types) {
-  // Parse comma-separated list of values with their types
-  // Format: %v1, %v2, ... : type1, type2, ...
-  if (parser.parseOperandList(values) || parser.parseColon() ||
-      parser.parseTypeList(types)) {
-    return failure();
-  }
-  return success();
-}
-
-static void printThreadLimitClause(OpAsmPrinter &p, Operation *op,
-                                   OperandRange values, TypeRange types) {
-  // Print values with their types
-  llvm::interleaveComma(values, p, [&](Value v) { p << v; });
-  p << " : ";
-  llvm::interleaveComma(types, p, [&](Type t) { p << t; });
-}
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 725d2d4345b3d..6a2ceac92eb09 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
   };
-  auto checkThreadLimitMultiDim = [&todo](auto op, LogicalResult &result) {
+  auto checkThreadLimit = [&todo](auto op, LogicalResult &result) {
     if (op.hasThreadLimitMultiDim())
       result = todo("thread_limit with multi-dimensional values");
   };
@@ -408,8 +408,8 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::TeamsOp op) {
         checkAllocate(op, result);
         checkPrivate(op, result);
-        checkNumTeamsMultiDim(op, result);
-        checkThreadLimitMultiDim(op, result);
+        checkNumTeams(op, result);
+        checkThreadLimit(op, result);
       })
       .Case([&](omp::TaskOp op) {
         checkAllocate(op, result);
@@ -447,7 +447,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
         checkAllocate(op, result);
         checkBare(op, result);
         checkInReduction(op, result);
-        checkThreadLimitMultiDim(op, result);
+        checkThreadLimit(op, result);
       })
       .Default([](Operation &) {
         // Assume all clauses for an operation can be translated unless they are
@@ -2082,7 +2082,7 @@ convertOmpTeams(omp::TeamsOp op, llvm::IRBuilderBase &builder,
 
   llvm::Value *threadLimit = nullptr;
   if (!op.getThreadLimitVals().empty())
-    threadLimit = moduleTranslation.lookupValue(op.getThreadLimitVal(0));
+    threadLimit = moduleTranslation.lookupValue(op.getThreadLimit(0));
 
   llvm::Value *ifExpr = nullptr;
   if (Value ifVar = op.getIfExpr())
@@ -6052,7 +6052,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
             else if (teamsOp.getNumTeamsUpper() == blockArg)
               numTeamsUpper = hostEvalVar;
             else if (!teamsOp.getThreadLimitVals().empty() &&
-                     teamsOp.getThreadLimitVal(0) == blockArg)
+                     teamsOp.getThreadLimit(0) == blockArg)
               threadLimit = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6173,7 +6173,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       numTeamsLower = teamsOp.getNumTeamsLower();
       numTeamsUpper = teamsOp.getNumTeamsUpper();
       if (!teamsOp.getThreadLimitVals().empty())
-        threadLimit = teamsOp.getThreadLimitVal(0);
+        threadLimit = teamsOp.getThreadLimit(0);
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
@@ -6219,7 +6219,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
   // Extract 'thread_limit' clause from 'target' and 'teams' directives.
   int32_t targetThreadLimitVal = -1, teamsThreadLimitVal = -1;
   if (!targetOp.getThreadLimitVals().empty())
-    setMaxValueFromClause(targetOp.getThreadLimitVal(0), targetThreadLimitVal);
+    setMaxValueFromClause(targetOp.getThreadLimit(0), targetThreadLimitVal);
   setMaxValueFromClause(threadLimit, teamsThreadLimitVal);
 
   // Extract 'max_threads' clause from 'parallel' or set to 1 if it's SIMD.
@@ -6299,7 +6299,7 @@ initTargetRuntimeAttrs(llvm::IRBuilderBase &builder,
 
   // TODO: Handle constant 'if' clauses.
   if (!targetOp.getThreadLimitVals().empty()) {
-    Value targetThreadLimit = targetOp.getThreadLimitVal(0);
+    Value targetThreadLimit = targetOp.getThreadLimit(0);
     attrs.TargetThreadLimit.front() =
         moduleTranslation.lookupValue(targetThreadLimit);
   }



More information about the llvm-branch-commits mailing list