[Mlir-commits] [mlir] 08654ad - [OpenMP][MLIR] Add num_threads clause with dims modifier support (#171767)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Tue Jan 27 02:01:03 PST 2026


Author: Chaitanya
Date: 2026-01-27T15:30:55+05:30
New Revision: 08654adc62c1d7c9a0f8fe40138fe4317a941f3b

URL: https://github.com/llvm/llvm-project/commit/08654adc62c1d7c9a0f8fe40138fe4317a941f3b
DIFF: https://github.com/llvm/llvm-project/commit/08654adc62c1d7c9a0f8fe40138fe4317a941f3b.diff

LOG: [OpenMP][MLIR] Add num_threads clause with dims modifier support (#171767)

This PR adds support for the OpenMP 6.1 num_threads clause with the dims modifier.
LLVM IR translation for num_threads with the dims modifier is marked as not yet implemented (NYI).

Added: 
    flang/test/Lower/OpenMP/num-threads-dims.f90

Modified: 
    flang/lib/Lower/OpenMP/ClauseProcessor.cpp
    flang/lib/Lower/OpenMP/Clauses.cpp
    flang/lib/Lower/OpenMP/OpenMP.cpp
    llvm/include/llvm/Frontend/OpenMP/ClauseT.h
    mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
    mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
    mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
    mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
    mlir/test/Dialect/OpenMP/ops.mlir
    mlir/test/Target/LLVMIR/openmp-todo.mlir

Removed: 
    


################################################################################
diff  --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index a067db300a31d..b29cfac857841 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -525,9 +525,11 @@ bool ClauseProcessor::processNumThreads(
     lower::StatementContext &stmtCtx,
     mlir::omp::NumThreadsClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
-    // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
-    result.numThreads =
-        fir::getBase(converter.genExprValue(clause->v, stmtCtx));
+    // OMPIRBuilder expects `NUM_THREADS` clause as a list of Values.
+    for (const ExprTy &expr : clause->v) {
+      result.numThreadsVars.push_back(
+          fir::getBase(converter.genExprValue(expr, stmtCtx)));
+    }
     return true;
   }
   return false;

diff  --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index e588d33bda4e5..9a20fb5d006f8 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1302,9 +1302,16 @@ NumTeams make(const parser::OmpClause::NumTeams &inp,
 NumThreads make(const parser::OmpClause::NumThreads &inp,
                 semantics::SemanticsContext &semaCtx) {
   // inp.v -> parser::OmpNumThreadsClause
-  auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
-  assert(!t1.empty());
-  return NumThreads{/*Nthreads=*/makeExpr(t1.front(), semaCtx)};
+  // With dims modifier (OpenMP 6.1): multiple values
+  // Without dims modifier: single value
+  auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+  assert(!values.empty());
+
+  List<NumThreads::Nthreads> v;
+  for (const auto &val : values) {
+    v.push_back(makeExpr(val, semaCtx));
+  }
+  return NumThreads{/*Nthreads=*/v};
 }
 
 // OmpxAttribute: empty

diff  --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 8757783945bd9..1902bd7a21f4a 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -100,8 +100,8 @@ class HostEvalInfo {
     for (auto numTeamsUpper : ops.numTeamsUpperVars)
       vars.push_back(numTeamsUpper);
 
-    if (ops.numThreads)
-      vars.push_back(ops.numThreads);
+    for (auto numThreads : ops.numThreadsVars)
+      vars.push_back(numThreads);
 
     if (ops.threadLimit)
       vars.push_back(ops.threadLimit);
@@ -116,7 +116,7 @@ class HostEvalInfo {
     assert(args.size() ==
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
-                   ops.numTeamsUpperVars.size() + (ops.numThreads ? 1 : 0) +
+                   ops.numTeamsUpperVars.size() + ops.numThreadsVars.size() +
                    (ops.threadLimit ? 1 : 0) &&
            "invalid block argument list");
     int argIndex = 0;
@@ -135,8 +135,8 @@ class HostEvalInfo {
     for (size_t i = 0; i < ops.numTeamsUpperVars.size(); ++i)
       ops.numTeamsUpperVars[i] = args[argIndex++];
 
-    if (ops.numThreads)
-      ops.numThreads = args[argIndex++];
+    for (size_t i = 0; i < ops.numThreadsVars.size(); ++i)
+      ops.numThreadsVars[i] = args[argIndex++];
 
     if (ops.threadLimit)
       ops.threadLimit = args[argIndex++];
@@ -170,13 +170,13 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::ParallelOperands &clauseOps) {
-    if (!ops.numThreads || parallelApplied) {
+    if (ops.numThreadsVars.empty() || parallelApplied) {
       parallelApplied = true;
       return false;
     }
 
     parallelApplied = true;
-    clauseOps.numThreads = ops.numThreads;
+    clauseOps.numThreadsVars = ops.numThreadsVars;
     return true;
   }
 

diff  --git a/flang/test/Lower/OpenMP/num-threads-dims.f90 b/flang/test/Lower/OpenMP/num-threads-dims.f90
new file mode 100644
index 0000000000000..f3a8d706b7283
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-threads-dims.f90
@@ -0,0 +1,61 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_threads` clause with dims modifier (OpenMP 6.1)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims4
+subroutine parallel_numthreads_dims4()
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32, i32)
+  !$omp parallel num_threads(dims(4): 4, 5, 6, 7)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_dims4
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims2
+subroutine parallel_numthreads_dims2()
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}} : i32, i32)
+  !$omp parallel num_threads(dims(2): 8, 4)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_dims2
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims_var
+subroutine parallel_numthreads_dims_var(a, b, c)
+  integer, intent(in) :: a, b, c
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i32, i32)
+  !$omp parallel num_threads(dims(3): a, b, c)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_dims_var
+
+!===============================================================================
+! `num_threads` clause without dims modifier (legacy)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPparallel_numthreads_legacy
+subroutine parallel_numthreads_legacy(n)
+  integer, intent(in) :: n
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}} : i32)
+  !$omp parallel num_threads(n)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_legacy
+
+! CHECK-LABEL: func @_QPparallel_numthreads_const
+subroutine parallel_numthreads_const()
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}} : i32)
+  !$omp parallel num_threads(16)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_const

diff  --git a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
index 32cb1be416eb0..86353f0fe0bda 100644
--- a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
+++ b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
@@ -1005,11 +1005,13 @@ struct NumTeamsT {
 };
 
 // V5.2: [10.1.2] `num_threads` clause
+// V6.1: Extended with dims modifier support
 template <typename T, typename I, typename E> //
 struct NumThreadsT {
   using Nthreads = E;
+  using List = ListT<Nthreads>;
   using WrapperTrait = std::true_type;
-  Nthreads v;
+  List v;
 };
 
 template <typename T, typename I, typename E> //

diff  --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
index 2f7169c6f3e2c..0bfe36648c54d 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauses.td
@@ -1107,16 +1107,47 @@ class OpenMP_NumThreadsClauseSkip<
   > : OpenMP_Clause<traits, arguments, assemblyFormat, description,
                     extraClassDeclaration> {
   let arguments = (ins
-    Optional<IntLikeType>:$num_threads
+    Variadic<IntLikeType>:$num_threads_vars
   );
 
   let optAssemblyFormat = [{
-    `num_threads` `(` $num_threads `:` type($num_threads) `)`
+    `num_threads` `(` $num_threads_vars `:` type($num_threads_vars) `)`
   }];
 
   let description = [{
-    The optional `num_threads` parameter specifies the number of threads which
-    should be used to execute the parallel region.
+    The `num_threads` clause specifies the number of threads.
+
+    Multi-dimensional format (dims modifier):
+    - Multiple values can be specified for multi-dimensional thread counts.
+    - The number of dimensions is derived from the number of values.
+    - Values can have different integer types.
+    - Format: `num_threads(%v1, %v2, ... : type1, type2, ...)`
+    - Example: `num_threads(%n, %m : i32, i64)`
+
+    Single value format:
+    - A single value specifies the number of threads.
+    - Format: `num_threads(%value : type)`
+    - Example: `num_threads(%n : i32)`
+  }];
+
+  let extraClassDeclaration = [{
+    /// Returns true if using multi-dimensional values (more than one value)
+    bool hasNumThreadsMultiDim() {
+      return getNumThreadsVars().size() > 1;
+    }
+
+    /// Returns the number of dimensions specified for num_threads
+    unsigned getNumThreadsDimsCount() {
+      return getNumThreadsVars().size();
+    }
+
+    /// Returns the value for a specific dimension index
+    /// Index must be less than getNumThreadsVars().size()
+    ::mlir::Value getNumThreads(unsigned dim = 0) {
+      assert(dim < getNumThreadsDimsCount() &&
+             "Num threads index out of bounds");
+      return getNumThreadsVars()[dim];
+    }
   }];
 }
 

diff  --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 5fcaea7f39c3c..48845734e9547 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -486,10 +486,11 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
     }
     rewriter.eraseOp(reduce);
 
-    Value numThreadsVar;
+    SmallVector<Value> numThreadsVars;
     if (numThreads > 0) {
-      numThreadsVar = LLVM::ConstantOp::create(
+      Value numThreadsVar = LLVM::ConstantOp::create(
           rewriter, loc, rewriter.getI32IntegerAttr(numThreads));
+      numThreadsVars.push_back(numThreadsVar);
     }
     // Create the parallel wrapper.
     auto ompParallel = omp::ParallelOp::create(
@@ -497,7 +498,7 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
         /* allocate_vars = */ llvm::SmallVector<Value>{},
         /* allocator_vars = */ llvm::SmallVector<Value>{},
         /* if_expr = */ Value{},
-        /* num_threads = */ numThreadsVar,
+        /* num_threads_vars = */ numThreadsVars,
         /* private_vars = */ ValueRange(),
         /* private_syms = */ nullptr,
         /* private_needs_barrier = */ nullptr,

diff  --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 8d06a9c66808a..349fc300e644d 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -2252,7 +2252,7 @@ LogicalResult TargetOp::verifyRegions() {
       if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
         if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
             parallelOp->isAncestor(capturedOp) &&
-            hostEvalArg == parallelOp.getNumThreads())
+            llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
           continue;
 
         return emitOpError()
@@ -2504,7 +2504,8 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        ArrayRef<NamedAttribute> attributes) {
   ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
                     /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
-                    /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
+                    /*num_threads_vars=*/ValueRange(),
+                    /*private_vars=*/ValueRange(),
                     /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
                     /*proc_bind_kind=*/nullptr,
                     /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
@@ -2516,7 +2517,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
                        const ParallelOperands &clauses) {
   MLIRContext *ctx = builder.getContext();
   ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
-                    clauses.ifExpr, clauses.numThreads, clauses.privateVars,
+                    clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
                     makeArrayAttr(ctx, clauses.privateSyms),
                     clauses.privateNeedsBarrier, clauses.procBindKind,
                     clauses.reductionMod, clauses.reductionVars,

diff  --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 5af1dea34f482..924108d5702f7 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -380,6 +380,10 @@ static LogicalResult checkImplementationStatus(Operation &op) {
     if (op.hasNumTeamsMultiDim())
       result = todo("num_teams with multi-dimensional values");
   };
+  auto checkNumThreads = [&todo](auto op, LogicalResult &result) {
+    if (op.hasNumThreadsMultiDim())
+      result = todo("num_threads with multi-dimensional values");
+  };
 
   LogicalResult result = success();
   llvm::TypeSwitch<Operation &>(op)
@@ -431,6 +435,7 @@ static LogicalResult checkImplementationStatus(Operation &op) {
       .Case([&](omp::ParallelOp op) {
         checkAllocate(op, result);
         checkReduction(op, result);
+        checkNumThreads(op, result);
       })
       .Case([&](omp::SimdOp op) { checkReduction(op, result); })
       .Case<omp::AtomicReadOp, omp::AtomicWriteOp, omp::AtomicUpdateOp,
@@ -3268,8 +3273,8 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
   if (auto ifVar = opInst.getIfExpr())
     ifCond = moduleTranslation.lookupValue(ifVar);
   llvm::Value *numThreads = nullptr;
-  if (auto numThreadsVar = opInst.getNumThreads())
-    numThreads = moduleTranslation.lookupValue(numThreadsVar);
+  if (!opInst.getNumThreadsVars().empty())
+    numThreads = moduleTranslation.lookupValue(opInst.getNumThreads(0));
   auto pbKind = llvm::omp::OMP_PROC_BIND_default;
   if (auto bind = opInst.getProcBindKind())
     pbKind = getProcBindKind(*bind);
@@ -6051,7 +6056,8 @@ extractHostEvalClauses(omp::TargetOp targetOp, Value &numThreads,
               llvm_unreachable("unsupported host_eval use");
           })
           .Case([&](omp::ParallelOp parallelOp) {
-            if (parallelOp.getNumThreads() == blockArg)
+            if (!parallelOp.getNumThreadsVars().empty() &&
+                parallelOp.getNumThreads(0) == blockArg)
               numThreads = hostEvalVar;
             else
               llvm_unreachable("unsupported host_eval use");
@@ -6170,8 +6176,10 @@ initTargetDefaultAttrs(omp::TargetOp targetOp, Operation *capturedOp,
       threadLimit = teamsOp.getThreadLimit();
     }
 
-    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp))
-      numThreads = parallelOp.getNumThreads();
+    if (auto parallelOp = castOrGetParentOfType<omp::ParallelOp>(capturedOp)) {
+      if (!parallelOp.getNumThreadsVars().empty())
+        numThreads = parallelOp.getNumThreads(0);
+    }
   }
 
   // Handle clauses impacting the number of teams.

diff  --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 3afab239e19c5..563729039c762 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -160,6 +160,18 @@ func.func @omp_parallel_pretty(%data_var : memref<i32>, %if_cond : i1, %num_thre
    omp.terminator
  }
 
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}} : i64, i64)
+ omp.parallel num_threads(%n_i64, %n_i64 : i64, i64) {
+   omp.terminator
+ }
+
+ %n_i16 = arith.constant 8 : i16
+ // Test num_threads with mixed types.
+ // CHECK: omp.parallel num_threads(%{{.*}}, %{{.*}}, %{{.*}} : i32, i64, i16)
+ omp.parallel num_threads(%num_threads, %n_i64, %n_i16 : i32, i64, i16) {
+   omp.terminator
+ }
+
  // CHECK: omp.parallel allocate(%{{.*}} : memref<i32> -> %{{.*}} : memref<i32>)
  omp.parallel allocate(%data_var : memref<i32> -> %data_var : memref<i32>) {
    omp.terminator

diff  --git a/mlir/test/Target/LLVMIR/openmp-todo.mlir b/mlir/test/Target/LLVMIR/openmp-todo.mlir
index 09919ae4f5267..ca5ec559926a9 100644
--- a/mlir/test/Target/LLVMIR/openmp-todo.mlir
+++ b/mlir/test/Target/LLVMIR/openmp-todo.mlir
@@ -443,6 +443,17 @@ llvm.func @teams_num_teams_multi_dim(%lb : i32, %ub : i32) {
 
 // -----
 
+llvm.func @parallel_num_threads_multi_dim(%lb : i32, %ub : i32) {
+  // expected-error @below {{not yet implemented: Unhandled clause num_threads with multi-dimensional values in omp.parallel operation}}
+  // expected-error @below {{LLVM Translation failed for operation: omp.parallel}}
+  omp.parallel num_threads(%lb, %ub : i32, i32) {
+    omp.terminator
+  }
+  llvm.return
+}
+
+// -----
+
 llvm.func @wsloop_allocate(%lb : i32, %ub : i32, %step : i32, %x : !llvm.ptr) {
   // expected-error @below {{not yet implemented: Unhandled clause allocate in omp.wsloop operation}}
   // expected-error @below {{LLVM Translation failed for operation: omp.wsloop}}


        


More information about the Mlir-commits mailing list