[llvm-branch-commits] [flang] [FLANG] Add flang to mlir lowering for num_teams (PR #175790)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 13 21:58:35 PST 2026


https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/175790

>From 9ced34d1958c0da71fc8b87f20671acab0e4cd54 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 Jan 2026 21:24:34 +0530
Subject: [PATCH] [FLANG] Add Flang-to-MLIR lowering for num_teams lower-bound and dims modifiers

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++-----
 flang/lib/Lower/OpenMP/Clauses.cpp         | 27 +++++++++--
 flang/lib/Lower/OpenMP/OpenMP.cpp          | 18 ++++++--
 flang/test/Lower/OpenMP/num-teams-dims.f90 | 52 ++++++++++++++++++++++
 4 files changed, 117 insertions(+), 17 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/num-teams-dims.f90

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..579ce359ed357 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -495,17 +495,34 @@ bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
 bool ClauseProcessor::processNumTeams(
     lower::StatementContext &stmtCtx,
     mlir::omp::NumTeamsClauseOps &result) const {
-  // TODO Get lower and upper bounds for num_teams when parser is updated to
-  // accept both.
   if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
-    // The num_teams directive accepts a list of team lower/upper bounds.
-    // This is an extension to support grid specification for ompx_bare.
-    // Here, only expect a single element in the list.
-    assert(clause->v.size() == 1);
-    // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t);
-    auto &upperBound = std::get<ExprTy>(clause->v[0].t);
-    result.numTeamsUpper =
-        fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+    // The clause is a list of Ranges. With a dims modifier there is one
+    // upper-bound-only Range per grid dimension; without it there is a single
+    // Range holding an optional lower bound and an upper bound.
+    assert(!clause->v.empty());
+
+    // The Range list carries no dims flag, so the modifier's presence is
+    // inferred from the list length. A one-item dims(1) list therefore takes
+    // the single-Range path below and is lowered through numTeamsUpper.
+    if (clause->v.size() > 1) {
+      // dims case: one upper bound per dimension; lower bounds are not read.
+      fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+      result.numTeamsNumDims = firOpBuilder.getI64IntegerAttr(clause->v.size());
+      for (const auto &range : clause->v) {
+        auto &upperBound = std::get<ExprTy>(range.t);
+        result.numTeamsDimsValues.push_back(
+            fir::getBase(converter.genExprValue(upperBound, stmtCtx)));
+      }
+    } else {
+      // Single Range: optional lower bound plus required upper bound.
+      auto &lowerBound = std::get<std::optional<ExprTy>>(clause->v[0].t);
+      auto &upperBound = std::get<ExprTy>(clause->v[0].t);
+      if (lowerBound)
+        result.numTeamsLower =
+            fir::getBase(converter.genExprValue(*lowerBound, stmtCtx));
+      result.numTeamsUpper =
+          fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+    }
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index a2716fb22a75c..01d0ff963ecf3 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1297,10 +1297,29 @@ NumTasks make(const parser::OmpClause::NumTasks &inp,
 NumTeams make(const parser::OmpClause::NumTeams &inp,
               semantics::SemanticsContext &semaCtx) {
   // inp.v -> parser::OmpNumTeamsClause
-  auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
-  assert(!t1.empty());
-  List<NumTeams::Range> v{{{/*LowerBound=*/std::nullopt,
-                            /*UpperBound=*/makeExpr(t1.front(), semaCtx)}}};
+  auto &mods = semantics::OmpGetModifiers(inp.v);
+  auto *dims = semantics::OmpGetUniqueModifier<parser::OmpDimsModifier>(mods);
+  auto *lowerBound =
+      semantics::OmpGetUniqueModifier<parser::OmpLowerBound>(mods);
+  auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+  assert(!values.empty());
+
+  // With a dims modifier: one upper-bound-only Range per list item. The
+  // dims argument is not stored; processNumTeams infers it from the length.
+  // NOTE(review): lowerBound is ignored here; confirm sema rejects that.
+  if (dims) {
+    List<NumTeams::Range> v;
+    for (const auto &val : values) {
+      v.push_back(NumTeams::Range{{/*LowerBound=*/std::nullopt,
+                                   /*UpperBound=*/makeExpr(val, semaCtx)}});
+    }
+    return NumTeams{/*List=*/v};
+  }
+
+  // Without a dims modifier: a single Range; only the first item is used.
+  auto lb = maybeApplyToV(makeExprFn(semaCtx), lowerBound);
+  List<NumTeams::Range> v{{{/*LowerBound=*/lb,
+                            /*UpperBound=*/makeExpr(values.front(), semaCtx)}}};
   return NumTeams{/*List=*/v};
 }
 
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 989e370870f33..e7a3cfcf52cd2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,6 +99,10 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       vars.push_back(ops.numTeamsUpper);
 
+    // num_teams dims upper bounds (OpenMP 6.1); order must match remapping.
+    for (mlir::Value val : ops.numTeamsDimsValues)
+      vars.push_back(val);
+
     if (ops.numThreads)
       vars.push_back(ops.numThreads);
 
@@ -115,8 +119,8 @@ class HostEvalInfo {
     assert(args.size() ==
                ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
                    ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
-                   (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
-                   (ops.threadLimit ? 1 : 0) &&
+                   (ops.numTeamsUpper ? 1 : 0) + ops.numTeamsDimsValues.size() +
+                   (ops.numThreads ? 1 : 0) + (ops.threadLimit ? 1 : 0) &&
            "invalid block argument list");
     int argIndex = 0;
     for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
@@ -134,6 +138,10 @@ class HostEvalInfo {
     if (ops.numTeamsUpper)
       ops.numTeamsUpper = args[argIndex++];
 
+    // num_teams dims upper bounds (OpenMP 6.1), same position as collected.
+    for (size_t i = 0; i < ops.numTeamsDimsValues.size(); ++i)
+      ops.numTeamsDimsValues[i] = args[argIndex++];
+
     if (ops.numThreads)
       ops.numThreads = args[argIndex++];
 
@@ -185,11 +193,15 @@ class HostEvalInfo {
   /// \returns whether an update was performed. If not, these clauses were not
   ///          evaluated in the host device.
   bool apply(mlir::omp::TeamsOperands &clauseOps) {
-    if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+    if (!ops.numTeamsLower && !ops.numTeamsUpper &&
+        ops.numTeamsDimsValues.empty() && !ops.threadLimit)
       return false;
 
     clauseOps.numTeamsLower = ops.numTeamsLower;
     clauseOps.numTeamsUpper = ops.numTeamsUpper;
+    // num_teams dims modifier (OpenMP 6.1): upper bounds and dimension count.
+    clauseOps.numTeamsDimsValues = ops.numTeamsDimsValues;
+    clauseOps.numTeamsNumDims = ops.numTeamsNumDims;
     clauseOps.threadLimit = ops.threadLimit;
     return true;
   }
diff --git a/flang/test/Lower/OpenMP/num-teams-dims.f90 b/flang/test/Lower/OpenMP/num-teams-dims.f90
new file mode 100644
index 0000000000000..fd5d5ba40c804
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-teams-dims.f90
@@ -0,0 +1,52 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_teams` clause with dims modifier (OpenMP 6.1)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_dims2
+subroutine teams_numteams_dims2()
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32)
+  !$omp teams num_teams(dims(2): 10, 4)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_dims2
+
+! CHECK-LABEL: func @_QPteams_numteams_dims3
+subroutine teams_numteams_dims3()
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  !$omp teams num_teams(dims(3): 8, 4, 2)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_dims3
+
+! CHECK-LABEL: func @_QPteams_numteams_dims_var
+subroutine teams_numteams_dims_var(a, b, c)
+  integer, intent(in) :: a, b, c
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  !$omp teams num_teams(dims(3): a, b, c)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_dims_var
+
+!===============================================================================
+! `num_teams` clause with lower-bound modifier (no dims modifier)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_lower_upper
+subroutine teams_numteams_lower_upper(lower, upper)
+  integer, intent(in) :: lower, upper
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(%{{.*}} : i32 to %{{.*}} : i32)
+  !$omp teams num_teams(lower: upper)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_lower_upper
+



More information about the llvm-branch-commits mailing list