[llvm-branch-commits] [flang] [mlir] [OpenMP][MLIR] Add num_threads clause with dims modifier support (PR #171767)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Jan 16 21:08:09 PST 2026


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/171767

>From 6093bdcf18e36ad0ef1b97c6c2cac8b8cd9000c3 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 11:56:58 +0530
Subject: [PATCH 1/7] [OpenMP][MLIR] Add num_threads clause with dims modifier
 support

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 50 +++++++++++-
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  2 +
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 79 +++++++++++++++++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 33 +++++++-
 mlir/test/Dialect/OpenMP/ops.mlir             | 15 ++--
 5 files changed, 163 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index d4640f254ed1f..aedfa05da1608 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,16 +1069,60 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
+    Variadic<AnyInteger>:$num_threads_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` custom<NumThreadsClause>(
+      $num_threads_dims, $num_threads_values, type($num_threads_values),
+      $num_threads, type($num_threads)
+    ) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    num_threads clause specifies the desired number of threads in the team
+    space formed by the construct on which it appears.
+
+    With dims modifier:
+    - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+    - Specifies upper bounds for each dimension (all must have same type)
+    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
+    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
+
+    Without dims modifier:
+    - Uses `num_threads`
+    - If lower bound not specified, it defaults to upper bound value
+    - Format: `num_threads(bounds : type)`
+    - Example: `num_threads(%ub : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if the dims modifier is explicitly present
+    bool hasDimsModifier() {
+      return getNumThreadsDims().has_value();
+    }
+
+    /// Returns the number of dimensions specified by dims modifier
+    unsigned getNumDimensions() {
+      if (!hasDimsModifier())
+        return 1;
+      return static_cast<unsigned>(*getNumThreadsDims());
+    }
+
+    /// Returns all dimension values as an operand range
+    ::mlir::OperandRange getDimensionValues() {
+      return getNumThreadsValues();
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumDimensions()
+    ::mlir::Value getDimensionValue(unsigned index) {
+      assert(index < getDimensionValues().size() &&
+             "Dimension index out of bounds");
+      return getDimensionValues()[index];
+    }
   }];
 }
 
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 6423d49859c97..0d5333ec2e455 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,6 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
+        /* num_threads_dims = */ nullptr,
+        /* num_threads_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 67ff9023a38da..9664b8f59802c 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2504,6 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
+                    /*num_threads_dims=*/nullptr,
+                    /*num_threads_values=*/ValueRange(),
                     /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2515,13 +2517,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
-                    makeArrayAttr(ctx, clauses.privateSyms),
-                    clauses.privateNeedsBarrier, clauses.procBindKind,
-                    clauses.reductionMod, clauses.reductionVars,
-                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-                    makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(
+      builder, state, clauses.allocateVars, clauses.allocatorVars,
+      clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+      clauses.numThreads, clauses.privateVars,
+      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
+      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
+      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+      makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2568,13 +2571,40 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
 }
 
 LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  auto numThreadsDims = getNumThreadsDims();
+  auto numThreadsValues = getNumThreadsValues();
+  auto numThreads = getNumThreads();
+
+  // num_threads with dims modifier
+  if (numThreadsDims.has_value() && numThreadsValues.empty()) {
+    return emitError(
+        "num_threads dims modifier requires values to be specified");
+  }
+
+  if (numThreadsDims.has_value() &&
+      numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
+    return emitError("num_threads dims(")
+           << *numThreadsDims << ") specified but " << numThreadsValues.size()
+           << " values provided";
+  }
+
+  // num_threads dims and number of threads cannot be used together
+  if (numThreadsDims.has_value() && numThreads) {
+    return emitError(
+        "num_threads dims and number of threads cannot be used together");
+  }
+
+  // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
         "expected equal sizes for allocate and allocator variables");
 
+  // verify private variables restrictions
   if (failed(verifyPrivateVarList(*this)))
     return failure();
 
+  // verify reduction variables restrictions
   return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                 getReductionByref());
 }
@@ -4595,6 +4625,41 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
+//===----------------------------------------------------------------------===//
+// Parser and printer for num_threads clause
+//===----------------------------------------------------------------------===//
+static ParseResult
+parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
+                      SmallVectorImpl<Type> &types,
+                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
+                      Type &boundsType) {
+  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
+    return success();
+  }
+
+  OpAsmParser::UnresolvedOperand boundsOperand;
+  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
+      parser.parseType(boundsType)) {
+    return failure();
+  }
+  bounds = boundsOperand;
+  return success();
+}
+
+static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
+                                  IntegerAttr dimsAttr, OperandRange values,
+                                  TypeRange types, Value bounds,
+                                  Type boundsType) {
+  if (!values.empty()) {
+    printDimsModifierWithValues(p, dimsAttr, values, types);
+  }
+  if (bounds) {
+    p.printOperand(bounds);
+    p << " : " << boundsType;
+  }
+}
+
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index bb882db73cbab..75431ec475954 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,6 +30,37 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
+func.func @num_threads_dims_no_values() {
+  // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+  "omp.parallel"() ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_mismatch(%n : i64) {
+  // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+  omp.parallel num_threads(dims(2): %n : i64) {
+    omp.terminator
+  }
+
+  return
+}
+
+// -----
+
+func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
+  // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+  "omp.parallel"(%n, %n, %m) ({
+    omp.terminator
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+  return
+}
+
+// -----
+
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
@@ -2708,7 +2739,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 89c7e5fd48bd9..3acbe010c28a5 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {
@@ -160,6 +160,11 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
+ omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

>From 97045e6201626b5f73e5178905a9a2cefa09b9cf Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 12:11:49 +0530
Subject: [PATCH 2/7] Mark mlir->llvmir translation for num_threads with dims
 as NYI

---
 .../Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp  | 15 ++++++++++++++-
 1 file changed, 14 insertions(+), 1 deletion(-)

diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 8a3a990e5a3fd..e66666b526069 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3268,6 +3268,10 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
+  // num_threads dims and values are not yet supported
+  assert(!opInst.getNumThreadsDims().has_value() &&
+         opInst.getNumThreadsValues().empty() &&
+         "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
@@ -6050,6 +6054,10 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
+            // num_threads dims and values are not yet supported
+            assert(!parallelOp.getNumThreadsDims().has_value() &&
+                   parallelOp.getNumThreadsValues().empty() &&
+                   "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
             else
@@ -6167,8 +6175,13 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      // num_threads dims and values are not yet supported
+      assert(!parallelOp.getNumThreadsDims().has_value() &&
+             parallelOp.getNumThreadsValues().empty() &&
+             "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
+    }
   }
 
   // Handle clauses impacting the number of teams.

>From 60288588459e658d9d2d1238569a19f34e932b80 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Thu, 11 Dec 2025 17:37:52 +0530
Subject: [PATCH 3/7] few more fixes

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 33 ++++++--------
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  4 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 44 +++++++++----------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      |  9 ++--
 mlir/test/Dialect/OpenMP/invalid.mlir         | 10 ++---
 5 files changed, 45 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index aedfa05da1608..3559002c6473f 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,14 +1069,14 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_dims,
-    Variadic<AnyInteger>:$num_threads_values,
+    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
+    Variadic<AnyInteger>:$num_threads_dims_values,
     Optional<IntLikeType>:$num_threads
   );
 
   let optAssemblyFormat = [{
     `num_threads` `(` custom<NumThreadsClause>(
-      $num_threads_dims, $num_threads_values, type($num_threads_values),
+      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
       $num_threads, type($num_threads)
     ) `)`
   }];
@@ -1086,7 +1086,7 @@ class OpenMP_NumThreadsClauseSkip<
     space formed by the construct on which it appears.
 
     With dims modifier:
-    - Uses `num_threads_dims` (dimension count) and `num_threads_values` (upper bounds list)
+    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
     - Specifies upper bounds for each dimension (all must have same type)
     - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
     - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
@@ -1100,28 +1100,23 @@ class OpenMP_NumThreadsClauseSkip<
 
   let extraClassDeclaration = [{
     /// Returns true if the dims modifier is explicitly present
-    bool hasDimsModifier() {
-      return getNumThreadsDims().has_value();
+    bool hasNumThreadsDimsModifier() {
+      return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
     }
 
     /// Returns the number of dimensions specified by dims modifier
-    unsigned getNumDimensions() {
-      if (!hasDimsModifier())
+    unsigned getNumThreadsDimsCount() {
+      if (!hasNumThreadsDimsModifier())
         return 1;
-      return static_cast<unsigned>(*getNumThreadsDims());
-    }
-
-    /// Returns all dimension values as an operand range
-    ::mlir::OperandRange getDimensionValues() {
-      return getNumThreadsValues();
+      return static_cast<unsigned>(*getNumThreadsNumDims());
     }
 
     /// Returns the value for a specific dimension index
-    /// Index must be less than getNumDimensions()
-    ::mlir::Value getDimensionValue(unsigned index) {
-      assert(index < getDimensionValues().size() &&
-             "Dimension index out of bounds");
-      return getDimensionValues()[index];
+    /// Index must be less than getNumThreadsDimsCount()
+    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
+      assert(index < getNumThreadsDimsCount() &&
+             "Num threads dims index out of bounds");
+      return getNumThreadsDimsValues()[index];
     }
   }];
 }
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 0d5333ec2e455..ab7bded7835be 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -448,8 +448,8 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
-        /* num_threads_dims = */ nullptr,
-        /* num_threads_values = */ llvm::SmallVector<Value>{},
+        /* num_threads_num_dims = */ nullptr,
+        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
         /* num_threads = */ numThreadsVar,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 9664b8f59802c..54ce42f684581 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2519,7 +2519,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
   MLIRContext *ctx = builder.getContext();
   ParallelOp::build(
       builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numThreadsDims, clauses.numThreadsValues,
+      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
       clauses.numThreads, clauses.privateVars,
       makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
       clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
@@ -2570,30 +2570,28 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
-LogicalResult ParallelOp::verify() {
-  // verify num_threads clause restrictions
-  auto numThreadsDims = getNumThreadsDims();
-  auto numThreadsValues = getNumThreadsValues();
-  auto numThreads = getNumThreads();
-
-  // num_threads with dims modifier
-  if (numThreadsDims.has_value() && numThreadsValues.empty()) {
-    return emitError(
-        "num_threads dims modifier requires values to be specified");
-  }
-
-  if (numThreadsDims.has_value() &&
-      numThreadsValues.size() != static_cast<size_t>(*numThreadsDims)) {
-    return emitError("num_threads dims(")
-           << *numThreadsDims << ") specified but " << numThreadsValues.size()
-           << " values provided";
+// Helper: Verify num_threads clause
+LogicalResult
+verifyNumThreadsClause(Operation *op,
+                       std::optional<IntegerAttr> numThreadsNumDims,
+                       OperandRange numThreadsDimsValues, Value numThreads) {
+  bool hasDimsModifier =
+      numThreadsNumDims.has_value() && numThreadsNumDims.value();
+  if (hasDimsModifier && numThreads) {
+    return op->emitError("num_threads with dims modifier cannot be used "
+                         "together with number of threads");
   }
+  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
+    return failure();
+  return success();
+}
 
-  // num_threads dims and number of threads cannot be used together
-  if (numThreadsDims.has_value() && numThreads) {
-    return emitError(
-        "num_threads dims and number of threads cannot be used together");
-  }
+LogicalResult ParallelOp::verify() {
+  // verify num_threads clause restrictions
+  if (failed(verifyNumThreadsClause(
+          getOperation(), this->getNumThreadsNumDimsAttr(),
+          this->getNumThreadsDimsValues(), this->getNumThreads())))
+    return failure();
 
   // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index e66666b526069..67f30383bb03a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3269,8 +3269,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
   // num_threads dims and values are not yet supported
-  assert(!opInst.getNumThreadsDims().has_value() &&
-         opInst.getNumThreadsValues().empty() &&
+  assert(!opInst.hasNumThreadsDimsModifier() &&
          "Lowering of num_threads with dims modifier is NYI.");
   if (auto numThreadsVar = opInst.getNumThreads())
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
@@ -6055,8 +6054,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
           })
           .Case([&](omp::ParallelOp parallelOp) {
             // num_threads dims and values are not yet supported
-            assert(!parallelOp.getNumThreadsDims().has_value() &&
-                   parallelOp.getNumThreadsValues().empty() &&
+            assert(!parallelOp.hasNumThreadsDimsModifier() &&
                    "Lowering of num_threads with dims modifier is NYI.");
             if (parallelOp.getNumThreads() == blockArg)
               numThreads = hostEvalVar;
@@ -6177,8 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
       // num_threads dims and values are not yet supported
-      assert(!parallelOp.getNumThreadsDims().has_value() &&
-             parallelOp.getNumThreadsValues().empty() &&
+      assert(!parallelOp.hasNumThreadsDimsModifier() &&
              "Lowering of num_threads with dims modifier is NYI.");
       numThreads = parallelOp.getNumThreads();
     }
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 75431ec475954..1c5ef785a17f9 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -31,17 +31,17 @@ func.func @num_threads_once(%n : si32) {
 // -----
 
 func.func @num_threads_dims_no_values() {
-  // expected-error@+1 {{num_threads dims modifier requires values to be specified}}
+  // expected-error@+1 {{dims modifier requires values to be specified}}
   "omp.parallel"() ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_dims = 2 : i64} : () -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
   return
 }
 
 // -----
 
 func.func @num_threads_dims_mismatch(%n : i64) {
-  // expected-error@+1 {{num_threads dims(2) specified but 1 values provided}}
+  // expected-error@+1 {{dims(2) specified but 1 values provided}}
   omp.parallel num_threads(dims(2): %n : i64) {
     omp.terminator
   }
@@ -52,10 +52,10 @@ func.func @num_threads_dims_mismatch(%n : i64) {
 // -----
 
 func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
-  // expected-error@+1 {{num_threads dims and number of threads cannot be used together}}
+  // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
   "omp.parallel"(%n, %n, %m) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_dims = 2 : i64} : (i64, i64, i64) -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
   return
 }
 

>From f07a41aa54d4f27ced44bba8b013e12b4f5ba1dd Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 19 Dec 2025 12:27:38 +0530
Subject: [PATCH 4/7] Use num_threads_dims_values only

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  4 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 15 ++---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 15 +++--
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  5 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 62 ++++++++-----------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 16 ++---
 mlir/test/Dialect/OpenMP/invalid.mlir         | 12 ++--
 mlir/test/Dialect/OpenMP/ops.mlir             | 10 +--
 8 files changed, 66 insertions(+), 73 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..abaeaa90f80be 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,8 +516,8 @@ bool ClauseProcessor::processNumThreads(
     mlir::omp::NumThreadsClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
     // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
-    result.numThreads =
-        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+    result.numThreadsDimsValues.push_back(
+        fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 989e370870f33..bdbabc292349a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,8 +99,8 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       vars.push_back(ops.numTeamsUpper);
 
-    if (ops.numThreads)
-      vars.push_back(ops.numThreads);
+    for (auto numThreads : ops.numThreadsDimsValues)
+      vars.push_back(numThreads);
 
     if (ops.threadLimit)
       vars.push_back(ops.threadLimit);
@@ -115,7 +115,8 @@ class HostEvalInfo {
     assert(args.size() ==
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
-                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
+                   (ops.numTeamsUpper ? 1 : 0) +
+                   ops.numThreadsDimsValues.size() +
                    (ops.threadLimit ? 1 : 0) &&
            "invalid block argument list");
     int argIndex = 0;
@@ -134,8 +135,8 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       ops.numTeamsUpper = args[argIndex++];
 
-    if (ops.numThreads)
-      ops.numThreads = args[argIndex++];
+    for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
+      ops.numThreadsDimsValues[i] = args[argIndex++];
 
     if (ops.threadLimit)
       ops.threadLimit = args[argIndex++];
@@ -169,13 +170,13 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::ParallelOperands &clauseOps) {
-    if (!ops.numThreads || parallelApplied) {
+    if (ops.numThreadsDimsValues.empty() || parallelApplied) {
       parallelApplied = true;
       return false;
     }
 
     parallelApplied = true;
-    clauseOps.numThreads = ops.numThreads;
+    clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
     return true;
   }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 3559002c6473f..8be7030599cc6 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1070,14 +1070,12 @@ class OpenMP_NumThreadsClauseSkip<
                     extraClassDeclaration> {
   let arguments = (ins
     ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
-    Variadic<AnyInteger>:$num_threads_dims_values,
-    Optional<IntLikeType>:$num_threads
+    Variadic<IntLikeType>:$num_threads_dims_values
   );
 
   let optAssemblyFormat = [{
     `num_threads` `(` custom<NumThreadsClause>(
-      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values),
-      $num_threads, type($num_threads)
+      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
     ) `)`
   }];
 
@@ -1092,10 +1090,9 @@ class OpenMP_NumThreadsClauseSkip<
     - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
 
     Without dims modifier:
-    - Uses `num_threads`
-    - If lower bound not specified, it defaults to upper bound value
-    - Format: `num_threads(bounds : type)`
-    - Example: `num_threads(%ub : i32)`
+    - The number of threads is specified by single value in `num_threads_dims_values`
+    - Format: `num_threads(value : type)`
+    - Example: `num_threads(%n : i32)`
   }];
 
   let extraClassDeclaration = [{
@@ -1116,6 +1113,8 @@ class OpenMP_NumThreadsClauseSkip<
     ::mlir::Value getNumThreadsDimsValue(unsigned index) {
       assert(index < getNumThreadsDimsCount() &&
              "Num threads dims index out of bounds");
+      if(getNumThreadsDimsValues().empty())
+        return nullptr;
       return getNumThreadsDimsValues()[index];
     }
   }];
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index ab7bded7835be..5d75613f9b2b6 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -438,9 +438,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     rewriter.eraseOp(reduce);
 
     Value numThreadsVar;
+    SmallVector<Value> numThreadsValues;
     if (numThreads > 0) {
       numThreadsVar = LLVM::ConstantOp::create(
           rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+      numThreadsValues.push_back(numThreadsVar);
     }
     // Create the parallel wrapper.
     auto ompParallel = omp::ParallelOp::create(
@@ -449,8 +451,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
         /* num_threads_num_dims = */ nullptr,
-        /* num_threads_dims_values = */ llvm::SmallVector<Value>{},
-        /* num_threads = */ numThreadsVar,
+        /* num_threads_dims_values = */ numThreadsValues,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
         /* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 54ce42f684581..6911272d43f6e 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,7 +2252,8 @@ LogicalResult TargetOp::verifyRegions() {
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
         if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
             parallelOp->isAncestor(capturedOp) &&
-            hostEvalArg == parallelOp.getNumThreads())
+            llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
+                               hostEvalArg))
           continue;
 
         return emitOpError()
@@ -2506,7 +2507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
                     /*num_threads_dims=*/nullptr,
                     /*num_threads_values=*/ValueRange(),
-                    /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+                    /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
                     /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
@@ -2517,14 +2518,14 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
 void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
-  ParallelOp::build(
-      builder, state, clauses.allocateVars, clauses.allocatorVars,
-      clauses.ifExpr, clauses.numThreadsNumDims, clauses.numThreadsDimsValues,
-      clauses.numThreads, clauses.privateVars,
-      makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
-      clauses.procBindKind, clauses.reductionMod, clauses.reductionVars,
-      makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
-      makeArrayAttr(ctx, clauses.reductionSyms));
+  ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
+                    clauses.ifExpr, clauses.numThreadsNumDims,
+                    clauses.numThreadsDimsValues, clauses.privateVars,
+                    makeArrayAttr(ctx, clauses.privateSyms),
+                    clauses.privateNeedsBarrier, clauses.procBindKind,
+                    clauses.reductionMod, clauses.reductionVars,
+                    makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
+                    makeArrayAttr(ctx, clauses.reductionSyms));
 }
 
 template <typename OpType>
@@ -2574,13 +2575,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
 LogicalResult
 verifyNumThreadsClause(Operation *op,
                        std::optional<IntegerAttr> numThreadsNumDims,
-                       OperandRange numThreadsDimsValues, Value numThreads) {
-  bool hasDimsModifier =
-      numThreadsNumDims.has_value() && numThreadsNumDims.value();
-  if (hasDimsModifier && numThreads) {
-    return op->emitError("num_threads with dims modifier cannot be used "
-                         "together with number of threads");
-  }
+                       OperandRange numThreadsDimsValues) {
   if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
     return failure();
   return success();
@@ -2588,9 +2583,9 @@ verifyNumThreadsClause(Operation *op,
 
 LogicalResult ParallelOp::verify() {
   // verify num_threads clause restrictions
-  if (failed(verifyNumThreadsClause(
-          getOperation(), this->getNumThreadsNumDimsAttr(),
-          this->getNumThreadsDimsValues(), this->getNumThreads())))
+  if (failed(verifyNumThreadsClause(getOperation(),
+                                    this->getNumThreadsNumDimsAttr(),
+                                    this->getNumThreadsDimsValues())))
     return failure();
 
   // verify allocate clause restrictions
@@ -4629,33 +4624,28 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
 static ParseResult
 parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      SmallVectorImpl<Type> &types,
-                      std::optional<OpAsmParser::UnresolvedOperand> &bounds,
-                      Type &boundsType) {
+                      SmallVectorImpl<Type> &types) {
   if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
     return success();
   }
 
-  OpAsmParser::UnresolvedOperand boundsOperand;
-  if (parser.parseOperand(boundsOperand) || parser.parseColon() ||
-      parser.parseType(boundsType)) {
+  // Without dims modifier: value : type
+  OpAsmParser::UnresolvedOperand singleValue;
+  Type singleType;
+  if (parser.parseOperand(singleValue) || parser.parseColon() ||
+      parser.parseType(singleType)) {
     return failure();
   }
-  bounds = boundsOperand;
+  values.push_back(singleValue);
+  types.push_back(singleType);
   return success();
 }
 
 static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
                                   IntegerAttr dimsAttr, OperandRange values,
-                                  TypeRange types, Value bounds,
-                                  Type boundsType) {
-  if (!values.empty()) {
-    printDimsModifierWithValues(p, dimsAttr, values, types);
-  }
-  if (bounds) {
-    p.printOperand(bounds);
-    p << " : " << boundsType;
-  }
+                                  TypeRange types) {
+  // Multidimensional: dims(N): values : type
+  printDimsModifierWithValues(p, dimsAttr, values, types);
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 67f30383bb03a..da44dda0a1230 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -3270,8 +3270,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   llvm::Value *numThreads = nullptr;
   // num_threads dims and values are not yet supported
   assert(!opInst.hasNumThreadsDimsModifier() &&
-         "Lowering of num_threads with dims modifier is NYI.");
-  if (auto numThreadsVar = opInst.getNumThreads())
+         "Lowering of num_threads with dims modifier is not yet implemented.");
+  if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
     numThreads = moduleTranslation.lookupValue(numThreadsVar);
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
@@ -6055,8 +6055,9 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
           .Case([&](omp::ParallelOp parallelOp) {
             // num_threads dims and values are not yet supported
             assert(!parallelOp.hasNumThreadsDimsModifier() &&
-                   "Lowering of num_threads with dims modifier is NYI.");
-            if (parallelOp.getNumThreads() == blockArg)
+                   "Lowering of num_threads with dims modifier is not yet "
+                   "implemented.");
+            if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
               numThreads = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6175,9 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
       // num_threads dims and values are not yet supported
-      assert(!parallelOp.hasNumThreadsDimsModifier() &&
-             "Lowering of num_threads with dims modifier is NYI.");
-      numThreads = parallelOp.getNumThreads();
+      assert(
+          !parallelOp.hasNumThreadsDimsModifier() &&
+          "Lowering of num_threads with dims modifier is not yet implemented.");
+      numThreads = parallelOp.getNumThreadsDimsValue(0);
     }
   }
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 1c5ef785a17f9..8a5e64b1a98ca 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -34,7 +34,7 @@ func.func @num_threads_dims_no_values() {
   // expected-error@+1 {{dims modifier requires values to be specified}}
   "omp.parallel"() ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
   return
 }
 
@@ -51,11 +51,11 @@ func.func @num_threads_dims_mismatch(%n : i64) {
 
 // -----
 
-func.func @num_threads_dims_and_scalar(%n : i64, %m: i64) {
-  // expected-error@+1 {{num_threads with dims modifier cannot be used together with number of threads}}
-  "omp.parallel"(%n, %n, %m) ({
+func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
+  // expected-error@+1 {{dims values can only be specified with dims modifier}}
+  "omp.parallel"(%n, %m) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,2,1,0,0>, num_threads_num_dims = 2 : i64} : (i64, i64, i64) -> ()
+  }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
   return
 }
 
@@ -2739,7 +2739,7 @@ func.func @undefined_privatizer(%arg0: index) {
 // -----
 func.func @undefined_privatizer(%arg0: !llvm.ptr) {
   // expected-error @below {{inconsistent number of private variables and privatizer op symbols, private vars: 1 vs. privatizer op symbols: 2}}
-  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
+  "omp.parallel"(%arg0) <{operandSegmentSizes = array<i32: 0, 0, 0, 0, 1, 0>, private_syms = [@x.privatizer, @y.privatizer]}> ({
     ^bb0(%arg2: !llvm.ptr):
       omp.terminator
     }) : (!llvm.ptr) -> ()
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3acbe010c28a5..4c57b8aea0b48 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -73,7 +73,7 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) num_threads(%{{.*}} : i32)
     "omp.parallel"(%data_var, %data_var, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,0,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,0,1,0,0>} : (memref<i32>, memref<i32>, i32) -> ()
 
   // CHECK: omp.barrier
     omp.barrier
@@ -82,22 +82,22 @@ func.func @omp_parallel(%data_var : memref<i32>, %if_cond : i1, %num_threads : i
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>) if(%{{.*}})
     "omp.parallel"(%data_var, %data_var, %if_cond) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
+    }) {operandSegmentSizes = array<i32: 1,1,1,0,0,0>} : (memref<i32>, memref<i32>, i1) -> ()
 
   // test without allocate
   // CHECK: omp.parallel if(%{{.*}}) num_threads(%{{.*}} : i32)
     "omp.parallel"(%if_cond, %num_threads) ({
       omp.terminator
-    }) {operandSegmentSizes = array<i32: 0,0,1,0,1,0,0>} : (i1, i32) -> ()
+    }) {operandSegmentSizes = array<i32: 0,0,1,1,0,0>} : (i1, i32) -> ()
 
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,1,0,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,1,1,0,0>, proc_bind_kind = #omp<procbindkind spread>} : (memref<i32>, memref<i32>, i1, i32) -> ()
 
   // test with multiple parameters for single variadic argument
   // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
   "omp.parallel" (%data_var, %data_var) ({
     omp.terminator
-  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
+  }) {operandSegmentSizes = array<i32: 1,1,0,0,0,0>} : (memref<i32>, memref<i32>) -> ()
 
   // CHECK: omp.parallel
   omp.parallel {

>From 038f9f4b3cfd4664f4df95e141178c6289194ac4 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Wed, 14 Jan 2026 12:07:56 +0530
Subject: [PATCH 5/7] fix adding numThreadsNumDims to ParallelOperands apply
 method

---
 flang/lib/Lower/OpenMP/OpenMP.cpp | 1 +
 1 file changed, 1 insertion(+)

diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index bdbabc292349a..5ca228e218c37 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -177,6 +177,7 @@ class HostEvalInfo {
 
     parallelApplied = true;
     clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
+    clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
     return true;
   }
 

>From 12c4749a7dc638ea4f22f2e1dd9cf9fd987f5123 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Fri, 16 Jan 2026 12:32:56 +0530
Subject: [PATCH 6/7] Remove dims(N) syntax and use list of vals for
 num_threads

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp    |  2 +-
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 14 +++--
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 53 +++++++++----------
 .../Conversion/SCFToOpenMP/SCFToOpenMP.cpp    |  3 +-
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 51 +++++-------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 26 ++++-----
 mlir/test/Dialect/OpenMP/invalid.mlir         | 31 -----------
 mlir/test/Dialect/OpenMP/ops.mlir             | 11 +++-
 mlir/test/Target/LLVMIR/openmp-todo.mlir      | 11 ++++
 9 files changed, 76 insertions(+), 126 deletions(-)

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index abaeaa90f80be..90825a3653016 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -516,7 +516,7 @@ bool ClauseProcessor::processNumThreads(
     mlir::omp::NumThreadsClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
     // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
-    result.numThreadsDimsValues.push_back(
+    result.numThreadsVals.push_back(
         fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
     return true;
   }
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 5ca228e218c37..c9271925580cd 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,7 +99,7 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       vars.push_back(ops.numTeamsUpper);
 
-    for (auto numThreads : ops.numThreadsDimsValues)
+    for (auto numThreads : ops.numThreadsVals)
       vars.push_back(numThreads);
 
     if (ops.threadLimit)
@@ -115,8 +115,7 @@ class HostEvalInfo {
     assert(args.size() ==
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
-                   (ops.numTeamsUpper ? 1 : 0) +
-                   ops.numThreadsDimsValues.size() +
+                   (ops.numTeamsUpper ? 1 : 0) + ops.numThreadsVals.size() +
                    (ops.threadLimit ? 1 : 0) &&
            "invalid block argument list");
     int argIndex = 0;
@@ -135,8 +134,8 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       ops.numTeamsUpper = args[argIndex++];
 
-    for (size_t i = 0; i < ops.numThreadsDimsValues.size(); ++i)
-      ops.numThreadsDimsValues[i] = args[argIndex++];
+    for (size_t i = 0; i < ops.numThreadsVals.size(); ++i)
+      ops.numThreadsVals[i] = args[argIndex++];
 
     if (ops.threadLimit)
       ops.threadLimit = args[argIndex++];
@@ -170,14 +169,13 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::ParallelOperands &clauseOps) {
-    if (ops.numThreadsDimsValues.empty() || parallelApplied) {
+    if (ops.numThreadsVals.empty() || parallelApplied) {
       parallelApplied = true;
       return false;
     }
 
     parallelApplied = true;
-    clauseOps.numThreadsDimsValues = ops.numThreadsDimsValues;
-    clauseOps.numThreadsNumDims = ops.numThreadsNumDims;
+    clauseOps.numThreadsVals = ops.numThreadsVals;
     return true;
   }
 
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 8be7030599cc6..90bff92fbc826 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1069,53 +1069,48 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
-    ConfinedAttr<OptionalAttr<I64Attr>, [IntPositive]>:$num_threads_num_dims,
-    Variadic<IntLikeType>:$num_threads_dims_values
+    Variadic<IntLikeType>:$num_threads_vals
   );
 
   let optAssemblyFormat = [{
     `num_threads` `(` custom<NumThreadsClause>(
-      $num_threads_num_dims, $num_threads_dims_values, type($num_threads_dims_values)
+      $num_threads_vals, type($num_threads_vals)
     ) `)`
   }];
 
   let description = [{
-    num_threads clause specifies the desired number of threads in the team
-    space formed by the construct on which it appears.
-
-    With dims modifier:
-    - Uses `num_threads_num_dims` (dimension count) and `num_threads_dims_values` (upper bounds list)
-    - Specifies upper bounds for each dimension (all must have same type)
-    - Format: `num_threads(dims(N): upper_bound_0, ..., upper_bound_N-1 : type)`
-    - Example: `num_threads(dims(3): %ub0, %ub1, %ub2 : i32)`
-
-    Without dims modifier:
-    - The number of threads is specified by single value in `num_threads_dims_values`
-    - Format: `num_threads(value : type)`
+    The `num_threads` clause specifies the number of threads.
+
+    Multi-dimensional format (dims modifier):
+    - Multiple values can be specified for multi-dimensional thread counts.
+    - The number of dimensions is derived from the number of values.
+    - Values can have different integer types.
+    - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)`
+    - Example: `num_threads(%n, %m : i32, i64)`
+
+    Single value format:
+    - A single value specifies the number of threads.
+    - Format: `num_threads(%value : type)`
     - Example: `num_threads(%n : i32)`
   }];
 
   let extraClassDeclaration = [{
-    /// Returns true if the dims modifier is explicitly present
-    bool hasNumThreadsDimsModifier() {
-      return getNumThreadsNumDims().has_value() && getNumThreadsNumDims().value();
+    /// Returns true if using multi-dimensional values (more than one value)
+    bool hasNumThreadsMultiDim() {
+      return getNumThreadsVals().size() > 1;
     }
 
-    /// Returns the number of dimensions specified by dims modifier
+    /// Returns the number of dimensions specified for num_threads
     unsigned getNumThreadsDimsCount() {
-      if (!hasNumThreadsDimsModifier())
-        return 1;
-      return static_cast<unsigned>(*getNumThreadsNumDims());
+      return getNumThreadsVals().size();
     }
 
     /// Returns the value for a specific dimension index
-    /// Index must be less than getNumThreadsDimsCount()
-    ::mlir::Value getNumThreadsDimsValue(unsigned index) {
-      assert(index < getNumThreadsDimsCount() &&
-             "Num threads dims index out of bounds");
-      if(getNumThreadsDimsValues().empty())
-        return nullptr;
-      return getNumThreadsDimsValues()[index];
+    /// Index must be less than getNumThreadsVals().size()
+    ::mlir::Value getNumThreadsVal(unsigned index) {
+      assert(index < getNumThreadsVals().size() &&
+             "Num threads index out of bounds");
+      return getNumThreadsVals()[index];
     }
   }];
 }
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 5d75613f9b2b6..6ba2155c7840f 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -450,8 +450,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
-        /* num_threads_num_dims = */ nullptr,
-        /* num_threads_dims_values = */ numThreadsValues,
+        /* num_threads_vals = */ numThreadsValues,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
         /* private_needs_barrier = */ nullptr,
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 6911272d43f6e..bc7647d129f60 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,8 +2252,7 @@ LogicalResult TargetOp::verifyRegions() {
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
         if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
             parallelOp->isAncestor(capturedOp) &&
-            llvm::is_contained(parallelOp.getNumThreadsDimsValues(),
-                               hostEvalArg))
+            llvm::is_contained(parallelOp.getNumThreadsVals(), hostEvalArg))
           continue;
 
         return emitOpError()
@@ -2505,8 +2504,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
-                    /*num_threads_dims=*/nullptr,
-                    /*num_threads_values=*/ValueRange(),
+                    /*num_threads_vals=*/ValueRange(),
                     /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
@@ -2519,8 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreadsNumDims,
-                    clauses.numThreadsDimsValues, clauses.privateVars,
+                    clauses.ifExpr, clauses.numThreadsVals, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
                     clauses.privateNeedsBarrier, clauses.procBindKind,
                     clauses.reductionMod, clauses.reductionVars,
@@ -2571,23 +2568,7 @@ static LogicalResult verifyPrivateVarList(OpType &op) {
   return success();
 }
 
-// Helper: Verify num_threads clause
-LogicalResult
-verifyNumThreadsClause(Operation *op,
-                       std::optional<IntegerAttr> numThreadsNumDims,
-                       OperandRange numThreadsDimsValues) {
-  if (failed(verifyDimsModifier(op, numThreadsNumDims, numThreadsDimsValues)))
-    return failure();
-  return success();
-}
-
 LogicalResult ParallelOp::verify() {
-  // verify num_threads clause restrictions
-  if (failed(verifyNumThreadsClause(getOperation(),
-                                    this->getNumThreadsNumDimsAttr(),
-                                    this->getNumThreadsDimsValues())))
-    return failure();
-
   // verify allocate clause restrictions
   if (getAllocateVars().size() != getAllocatorVars().size())
     return emitError(
@@ -4622,30 +4603,24 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
 // Parser and printer for num_threads clause
 //===----------------------------------------------------------------------===//
 static ParseResult
-parseNumThreadsClause(OpAsmParser &parser, IntegerAttr &dimsAttr,
+parseNumThreadsClause(OpAsmParser &parser,
                       SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
                       SmallVectorImpl<Type> &types) {
-  if (succeeded(parseDimsModifierWithValues(parser, dimsAttr, values, types))) {
-    return success();
-  }
-
-  // Without dims modifier: value : type
-  OpAsmParser::UnresolvedOperand singleValue;
-  Type singleType;
-  if (parser.parseOperand(singleValue) || parser.parseColon() ||
-      parser.parseType(singleType)) {
+  // Parse comma-separated list of values with their types
+  // Format: %v1, %v2, ... : type1, type2, ...
+  if (parser.parseOperandList(values) || parser.parseColon() ||
+      parser.parseTypeList(types)) {
     return failure();
   }
-  values.push_back(singleValue);
-  types.push_back(singleType);
   return success();
 }
 
 static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
-                                  IntegerAttr dimsAttr, OperandRange values,
-                                  TypeRange types) {
-  // Multidimensional: dims(N): values : type
-  printDimsModifierWithValues(p, dimsAttr, values, types);
+                                  OperandRange values, TypeRange types) {
+  // Print values with their types
+  llvm::interleaveComma(values, p, [&](Value v) { p << v; });
+  p << " : ";
+  llvm::interleaveComma(types, p, [&](Type t) { p << t; });
 }
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index da44dda0a1230..2fd3da1b5b30a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
   };
+  auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) {
+    if (op.hasNumThreadsMultiDim())
+      result = todo("num_threads with multi-dimensional values");
+  };
 
   LogicalResult result = success();
   llvm::TypeSwitch<Operation &>(op)
@@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::ParallelOp op) {
         checkAllocate(op, result);
         checkReduction(op, result);
+        checkNumThreadsMultiDim(op, result);
       })
       .Case([&](omp::SimdOp op) { checkReduction(op, result); })
       .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3268,11 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
-  // num_threads dims and values are not yet supported
-  assert(!opInst.hasNumThreadsDimsModifier() &&
-         "Lowering of num_threads with dims modifier is not yet implemented.");
-  if (auto numThreadsVar = opInst.getNumThreadsDimsValue(0))
-    numThreads = moduleTranslation.lookupValue(numThreadsVar);
+  if (!opInst.getNumThreadsVals().empty())
+    numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0));
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
     pbKind = getProcBindKind(*bind);
@@ -6053,11 +6055,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
-            // num_threads dims and values are not yet supported
-            assert(!parallelOp.hasNumThreadsDimsModifier() &&
-                   "Lowering of num_threads with dims modifier is not yet "
-                   "implemented.");
-            if (parallelOp.getNumThreadsDimsValue(0) == blockArg)
+            if (!parallelOp.getNumThreadsVals().empty() &&
+                parallelOp.getNumThreadsVal(0) == blockArg)
               numThreads = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6175,11 +6174,8 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
     }
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
-      // num_threads dims and values are not yet supported
-      assert(
-          !parallelOp.hasNumThreadsDimsModifier() &&
-          "Lowering of num_threads with dims modifier is not yet implemented.");
-      numThreads = parallelOp.getNumThreadsDimsValue(0);
+      if (!parallelOp.getNumThreadsVals().empty())
+        numThreads = parallelOp.getNumThreadsVal(0);
     }
   }
 
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index 8a5e64b1a98ca..bb882db73cbab 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -30,37 +30,6 @@ func.func @num_threads_once(%n : si32) {
 
 // -----
 
-func.func @num_threads_dims_no_values() {
-  // expected-error@+1 {{dims modifier requires values to be specified}}
-  "omp.parallel"() ({
-    omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,0,0,0>, num_threads_num_dims = 2 : i64} : () -> ()
-  return
-}
-
-// -----
-
-func.func @num_threads_dims_mismatch(%n : i64) {
-  // expected-error@+1 {{dims(2) specified but 1 values provided}}
-  omp.parallel num_threads(dims(2): %n : i64) {
-    omp.terminator
-  }
-
-  return
-}
-
-// -----
-
-func.func @num_threads_multiple_values_without_dims(%n : i64, %m: i64) {
-  // expected-error@+1 {{dims values can only be specified with dims modifier}}
-  "omp.parallel"(%n, %m) ({
-    omp.terminator
-  }) {operandSegmentSizes = array<i32: 0,0,0,2,0,0>} : (i64, i64) -> ()
-  return
-}
-
-// -----
-
 func.func @nowait_not_allowed(%n : memref<i32>) {
   // expected-error@+1 {{expected '{' to begin a region}}
   omp.parallel nowait {}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 4c57b8aea0b48..67f93869d4be7 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -160,8 +160,15 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
- // CHECK: omp.parallel num_threads(dims(2): %{{.*}}, %{{.*}} : i64)
- omp.parallel num_threads(dims(2): %n_i64, %n_i64 : i64) {
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64)
+ omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) {
+   omp.terminator
+ }
+
+ %n_i16 = arith.constant 8 : i16
+ // Test num_threads with mixed types.
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+ omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) {
    omp.terminator
  }
 
diff --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 3681ce38bd523..fd218e91d0b46 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -452,6 +452,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
+  // expected-error@below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+  // expected-error@below {{LLVM Translation failed for operation: omp.parallel}}
+  omp.parallel num_threads(%lb, %ub : i32, i32) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
   // expected-error@below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
   // expected-error@below {{LLVM Translation failed for operation: omp.wsloop}}

>From 1033cc66ab5617df178499b6138a6f00a7da18f5 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Sat, 17 Jan 2026 10:37:09 +0530
Subject: [PATCH 7/7] remove custom parser printer for num_threads

---
 .../mlir/Dialect/OpenMP/OpenMPClauses.td      | 10 ++++----
 mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp  | 24 -------------------
 .../OpenMP/OpenMPToLLVMIRTranslation.cpp      | 10 ++++----
 3 files changed, 9 insertions(+), 35 deletions(-)

diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 90bff92fbc826..7d0e1e3f91af4 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1073,9 +1073,7 @@ class OpenMP_NumThreadsClauseSkip<
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` custom<NumThreadsClause>(
-      $num_threads_vals, type($num_threads_vals)
-    ) `)`
+    `num_threads` `(` $num_threads_vals `:` type($num_threads_vals) `)`
   }];
 
   let description = [{
@@ -1107,10 +1105,10 @@ class OpenMP_NumThreadsClauseSkip<
 
     /// Returns the value for a specific dimension index
     /// Index must be less than getNumThreadsVals().size()
-    ::mlir::Value getNumThreadsVal(unsigned index) {
-      assert(index < getNumThreadsVals().size() &&
+    ::mlir::Value getNumThreads(unsigned dim = 0) {
+      assert(dim < getNumThreadsDimsCount() &&
              "Num threads index out of bounds");
-      return getNumThreadsVals()[index];
+      return getNumThreadsVals()[dim];
     }
   }];
 }
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index bc7647d129f60..ab1038c755f7a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -4599,30 +4599,6 @@ static void printNumTeamsClause(OpAsmPrinter &p, Operation *op,
   }
 }
 
-//===----------------------------------------------------------------------===//
-// Parser and printer for num_threads clause
-//===----------------------------------------------------------------------===//
-static ParseResult
-parseNumThreadsClause(OpAsmParser &parser,
-                      SmallVectorImpl<OpAsmParser::UnresolvedOperand> &values,
-                      SmallVectorImpl<Type> &types) {
-  // Parse comma-separated list of values with their types
-  // Format: %v1, %v2, ... : type1, type2, ...
-  if (parser.parseOperandList(values) || parser.parseColon() ||
-      parser.parseTypeList(types)) {
-    return failure();
-  }
-  return success();
-}
-
-static void printNumThreadsClause(OpAsmPrinter &p, Operation *op,
-                                  OperandRange values, TypeRange types) {
-  // Print values with their types
-  llvm::interleaveComma(values, p, [&](Value v) { p << v; });
-  p << " : ";
-  llvm::interleaveComma(types, p, [&](Type t) { p << t; });
-}
-
 #define GET_ATTRDEF_CLASSES
 #include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
 
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 2fd3da1b5b30a..b92ec9332d43a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,7 +380,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
   };
-  auto checkNumThreadsMultiDim = [&todo](auto op, LogicalResult &result) {
+  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
     if (op.hasNumThreadsMultiDim())
       result = todo("num_threads with multi-dimensional values");
   };
@@ -435,7 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::ParallelOp op) {
         checkAllocate(op, result);
         checkReduction(op, result);
-        checkNumThreadsMultiDim(op, result);
+        checkNumThreads(op, result);
       })
       .Case([&](omp::SimdOp op) { checkReduction(op, result); })
       .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3274,7 +3274,7 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
   if (!opInst.getNumThreadsVals().empty())
-    numThreads = moduleTranslation.lookupValue(opInst.getNumThreadsVal(0));
+    numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
     pbKind = getProcBindKind(*bind);
@@ -6056,7 +6056,7 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
           })
           .Case([&](omp::ParallelOp parallelOp) {
             if (!parallelOp.getNumThreadsVals().empty() &&
-                parallelOp.getNumThreadsVal(0) == blockArg)
+                parallelOp.getNumThreads(0) == blockArg)
               numThreads = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6175,7 +6175,7 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
 
     if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
       if (!parallelOp.getNumThreadsVals().empty())
-        numThreads = parallelOp.getNumThreadsVal(0);
+        numThreads = parallelOp.getNumThreads(0);
     }
   }
 



More information about the llvm-branch-commits mailing list