[llvm-branch-commits] [flang] [FLANG] Add flang to mlir lowering for num_teams (PR #175790)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Tue Jan 13 08:29:35 PST 2026


https://github.com/skc7 created https://github.com/llvm/llvm-project/pull/175790

None

>From 75cf8e2211eb9dce46ed9f6f5e57643efaddf280 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 Jan 2026 21:24:34 +0530
Subject: [PATCH] [FLANG] Add flang to mlir lowering for num_teams

---
 flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++-----
 flang/lib/Lower/OpenMP/Clauses.cpp         | 26 +++++++++--
 flang/test/Lower/OpenMP/num-teams-dims.f90 | 52 ++++++++++++++++++++++
 3 files changed, 101 insertions(+), 14 deletions(-)
 create mode 100644 flang/test/Lower/OpenMP/num-teams-dims.f90

diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..579ce359ed357 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -495,17 +495,34 @@ bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
 bool ClauseProcessor::processNumTeams(
     lower::StatementContext &stmtCtx,
     mlir::omp::NumTeamsClauseOps &result) const {
-  // TODO Get lower and upper bounds for num_teams when parser is updated to
-  // accept both.
   if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
-    // The num_teams directive accepts a list of team lower/upper bounds.
-    // This is an extension to support grid specification for ompx_bare.
-    // Here, only expect a single element in the list.
-    assert(clause->v.size() == 1);
-    // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t);
-    auto &upperBound = std::get<ExprTy>(clause->v[0].t);
-    result.numTeamsUpper =
-        fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+    // The clause holds a list of ranges. More than one range means the dims
+    // modifier was used: each range carries only an upper bound, one per grid
+    // dimension. Otherwise it is one range with an optional lower bound.
+    assert(!clause->v.empty());
+
+    // A single-element list (including `dims(1): n`) takes the non-dims path
+    // below and is lowered to numTeamsLower/numTeamsUpper rather than to
+    // numTeamsNumDims/numTeamsDimsValues.
+    if (clause->v.size() > 1) {
+      // Dims modifier case: one upper bound per grid dimension.
+      fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+      result.numTeamsNumDims = firOpBuilder.getI64IntegerAttr(clause->v.size());
+      for (const auto &range : clause->v) {
+        const auto &upperBound = std::get<ExprTy>(range.t);
+        result.numTeamsDimsValues.push_back(
+            fir::getBase(converter.genExprValue(upperBound, stmtCtx)));
+      }
+    } else {
+      // Non-dims case: a required upper bound and an optional lower bound.
+      const auto &lowerBound = std::get<std::optional<ExprTy>>(clause->v[0].t);
+      const auto &upperBound = std::get<ExprTy>(clause->v[0].t);
+      if (lowerBound)
+        result.numTeamsLower =
+            fir::getBase(converter.genExprValue(*lowerBound, stmtCtx));
+      result.numTeamsUpper =
+          fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+    }
     return true;
   }
   return false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index a2716fb22a75c..a2a952c2c2ba8 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1297,10 +1297,28 @@ NumTasks make(const parser::OmpClause::NumTasks &inp,
 NumTeams make(const parser::OmpClause::NumTeams &inp,
               semantics::SemanticsContext &semaCtx) {
   // inp.v -> parser::OmpNumTeamsClause
-  auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
-  assert(!t1.empty());
-  List<NumTeams::Range> v{{{/*LowerBound=*/std::nullopt,
-                            /*UpperBound=*/makeExpr(t1.front(), semaCtx)}}};
+  auto &mods = semantics::OmpGetModifiers(inp.v);
+  auto *dims = semantics::OmpGetUniqueModifier<parser::OmpDimsModifier>(mods);
+  auto *lowerBound =
+      semantics::OmpGetUniqueModifier<parser::OmpLowerBound>(mods);
+  auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+  assert(!values.empty());
+
+  // dims(n): one Range per value, each with only an upper bound. n itself is
+  // not stored; it is implied by the list length (assumed validated by sema).
+  if (dims) {
+    List<NumTeams::Range> v;
+    for (const auto &val : values) {
+      v.push_back(NumTeams::Range{{/*LowerBound=*/std::nullopt,
+                                   /*UpperBound=*/makeExpr(val, semaCtx)}});
+    }
+    return NumTeams{/*List=*/v};
+  }
+
+  // No dims modifier: a single Range with an optional lower bound.
+  auto lb = maybeApplyToV(makeExprFn(semaCtx), lowerBound);
+  List<NumTeams::Range> v{{{/*LowerBound=*/lb,
+                            /*UpperBound=*/makeExpr(values.front(), semaCtx)}}};
   return NumTeams{/*List=*/v};
 }
 
diff --git a/flang/test/Lower/OpenMP/num-teams-dims.f90 b/flang/test/Lower/OpenMP/num-teams-dims.f90
new file mode 100644
index 0000000000000..fd5d5ba40c804
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-teams-dims.f90
@@ -0,0 +1,52 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_teams` with dims modifier (OpenMP 6.1): one upper bound per dimension
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_dims2
+subroutine teams_numteams_dims2()
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32)
+  !$omp teams num_teams(dims(2): 10, 4)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_dims2
+
+! CHECK-LABEL: func @_QPteams_numteams_dims3
+subroutine teams_numteams_dims3()
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  !$omp teams num_teams(dims(3): 8, 4, 2)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_dims3
+
+! CHECK-LABEL: func @_QPteams_numteams_dims_var
+subroutine teams_numteams_dims_var(a, b, c)
+  integer, intent(in) :: a, b, c
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+  !$omp teams num_teams(dims(3): a, b, c)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_dims_var
+
+!===============================================================================
+! `num_teams` with lower:upper bounds (no dims modifier)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_lower_upper
+subroutine teams_numteams_lower_upper(lower, upper)
+  integer, intent(in) :: lower, upper
+  ! CHECK: omp.teams
+  ! CHECK-SAME: num_teams(%{{.*}} : i32 to %{{.*}} : i32)
+  !$omp teams num_teams(lower: upper)
+  call f1()
+  ! CHECK: omp.terminator
+  !$omp end teams
+end subroutine teams_numteams_lower_upper
+



More information about the llvm-branch-commits mailing list