[llvm-branch-commits] [flang] [llvm] [Flang] Add lowering from flang to mlir for num_threads (PR #175792)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 13 08:31:35 PST 2026


https://github.com/skc7 created https://github.com/llvm/llvm-project/pull/175792

None

>From 61f2ce0f8fafad3ca2e7cd09d4c5f36e1c9c972b Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 Jan 2026 21:39:41 +0530
Subject: [PATCH] [Flang] Add lowering from flang to mlir for num_threads

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp   | 20 ++++++-
 flang/lib/Lower/OpenMP/Clauses.cpp           | 13 ++++-
 flang/test/Lower/OpenMP/num-threads-dims.f90 | 61 ++++++++++++++++++++
 llvm/include/llvm/Frontend/OpenMP/ClauseT.h  |  4 +-
 4 files changed, 91 insertions(+), 7 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/num-threads-dims.f90

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index abaeaa90f80be..b649bfbc3a8c5 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -515,9 +515,23 @@ bool ClauseProcessor::processNumThreads(
     lower::StatementContext &stmtCtx,
     mlir::omp::NumThreadsClauseOps &result) const {
   if (auto *clause = findUniqueClause<omp::clause::NumThreads>()) {
-    // OMPIRBuilder expects `NUM_THREADS` clause as a `Value`.
-    result.numThreadsDimsValues.push_back(
-        fir::getBase(converter.genExprValue(clause->v, stmtCtx)));
+    // The num_threads clause accepts a list of values.
+    // With dims modifier (OpenMP 6.1): multiple values, one per dimension.
+    // Without dims modifier: a single value.
+    assert(!clause->v.empty());
+
+    // If multiple values, this indicates dims modifier is present
+    if (clause->v.size() > 1) {
+      fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+      result.numThreadsNumDims =
+          firOpBuilder.getI64IntegerAttr(clause->v.size());
+    }
+
+    // Populate all values
+    for (const auto &val : clause->v) {
+      result.numThreadsDimsValues.push_back(
+          fir::getBase(converter.genExprValue(val, stmtCtx)));
+    }
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index a2716fb22a75c..b2779d8e5d0cc 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1307,9 +1307,16 @@ NumTeams make(const parser::OmpClause::NumTeams &inp,
 NumThreads make(const parser::OmpClause::NumThreads &inp,
                 semantics::SemanticsContext &semaCtx) {
   // inp.v -> parser::OmpNumThreadsClause
-  auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
-  assert(!t1.empty());
-  return NumThreads{/*Nthreads=*/makeExpr(t1.front(), semaCtx)};
+  // With dims modifier (OpenMP 6.1): multiple values for multi-dimensional grid
+  // Without dims modifier: single value
+  auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+  assert(!values.empty());
+
+  List<NumThreads::Nthreads> v;
+  for (const auto &val : values) {
+    v.push_back(makeExpr(val, semaCtx));
+  }
+  return NumThreads{/*Nthreads=*/v};
 }
 
 // OmpxAttribute: empty
diff --git a/flang/test/Lower/OpenMP/num-threads-dims.f90 b/flang/test/Lower/OpenMP/num-threads-dims.f90
new file mode 100644
index 0000000000000..71d04a745b1dc
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-threads-dims.f90
@@ -0,0 +1,61 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_threads` clause with dims modifier (OpenMP 6.1)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims4
+subroutine parallel_numthreads_dims4()
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(dims(4): %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  !$omp parallel num_threads(dims(4): 4, 5, 6, 7)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_dims4
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims2
+subroutine parallel_numthreads_dims2()
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(dims(2): %{{.*}}, %{{.*}} : i32)
+  !$omp parallel num_threads(dims(2): 8, 4)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_dims2
+
+! CHECK-LABEL: func @_QPparallel_numthreads_dims_var
+subroutine parallel_numthreads_dims_var(a, b, c)
+  integer, intent(in) :: a, b, c
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  !$omp parallel num_threads(dims(3): a, b, c)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_dims_var
+
+!===============================================================================
+! `num_threads` clause without dims modifier (legacy)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPparallel_numthreads_legacy
+subroutine parallel_numthreads_legacy(n)
+  integer, intent(in) :: n
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}} : i32)
+  !$omp parallel num_threads(n)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_legacy
+
+! CHECK-LABEL: func @_QPparallel_numthreads_const
+subroutine parallel_numthreads_const()
+  ! CHECK: omp.parallel
+  ! CHECK-SAME: num_threads(%{{.*}} : i32)
+  !$omp parallel num_threads(16)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end parallel
+end subroutine parallel_numthreads_const
diff --git a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
index 7543f27136e7d..7fade94c2ce30 100644
--- a/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
+++ b/llvm/include/llvm/Frontend/OpenMP/ClauseT.h
@@ -1020,11 +1020,13 @@ struct NumTeamsT {
 };
 
 // V5.2: [10.1.2] `num_threads` clause
+// V6.1: Extended with dims modifier support
 template <typename T, typename I, typename E> //
 struct NumThreadsT {
   using Nthreads = E;
+  // Changed to list to support dims modifier with multiple values (OpenMP 6.1)
   using WrapperTrait = std::true_type;
-  Nthreads v;
+  ListT<Nthreads> v;
 };
 
 template <typename T, typename I, typename E> //



More information about the llvm-branch-commits mailing list