[llvm-branch-commits] [flang] [FLANG] Add flang to mlir lowering for num_teams (PR #175790)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 13 21:58:35 PST 2026
https://github.com/skc7 updated https://github.com/llvm/llvm-project/pull/175790
>From 9ced34d1958c0da71fc8b87f20671acab0e4cd54 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 Jan 2026 21:24:34 +0530
Subject: [PATCH] [FLANG] Add flang to mlir lowering for num_teams
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++-----
flang/lib/Lower/OpenMP/Clauses.cpp | 27 +++++++++--
flang/lib/Lower/OpenMP/OpenMP.cpp | 18 ++++++--
flang/test/Lower/OpenMP/num-teams-dims.f90 | 52 ++++++++++++++++++++++
4 files changed, 117 insertions(+), 17 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/num-teams-dims.f90
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..579ce359ed357 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -495,17 +495,34 @@ bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
bool ClauseProcessor::processNumTeams(
lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const {
- // TODO Get lower and upper bounds for num_teams when parser is updated to
- // accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
- // The num_teams directive accepts a list of team lower/upper bounds.
- // This is an extension to support grid specification for ompx_bare.
- // Here, only expect a single element in the list.
- assert(clause->v.size() == 1);
- // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t);
- auto &upperBound = std::get<ExprTy>(clause->v[0].t);
- result.numTeamsUpper =
- fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ // Lower num_teams into `result` (returns true iff the clause is present).
+ // Each list entry is a Range (optional lower bound, upper bound); the dims
+ // modifier yields one entry per dimension, otherwise there is exactly one.
+ assert(!clause->v.empty());
+
+ // The dims modifier is inferred from the list length alone: dims(1): n
+ // yields one entry and takes the single-Range path below, i.e. it is
+ // lowered like num_teams(n). NOTE(review): confirm this is intended.
+ if (clause->v.size() > 1) {
+ // Dims case: one upper bound per dimension; dimension count = list size.
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ result.numTeamsNumDims = firOpBuilder.getI64IntegerAttr(clause->v.size());
+ for (const auto &range : clause->v) {
+ auto &upperBound = std::get<ExprTy>(range.t);
+ result.numTeamsDimsValues.push_back(
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx)));
+ }
+ } else {
+ // Single-Range case: optional lower bound and required upper bound.
+ auto &lowerBound = std::get<std::optional<ExprTy>>(clause->v[0].t);
+ auto &upperBound = std::get<ExprTy>(clause->v[0].t);
+ if (lowerBound)
+ result.numTeamsLower =
+ fir::getBase(converter.genExprValue(*lowerBound, stmtCtx));
+ result.numTeamsUpper =
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ }
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index a2716fb22a75c..01d0ff963ecf3 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1297,10 +1297,29 @@ NumTasks make(const parser::OmpClause::NumTasks &inp,
NumTeams make(const parser::OmpClause::NumTeams &inp,
semantics::SemanticsContext &semaCtx) {
// inp.v -> parser::OmpNumTeamsClause
- auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
- assert(!t1.empty());
- List<NumTeams::Range> v{{{/*LowerBound=*/std::nullopt,
- /*UpperBound=*/makeExpr(t1.front(), semaCtx)}}};
+ auto &mods = semantics::OmpGetModifiers(inp.v);
+ auto *dims = semantics::OmpGetUniqueModifier<parser::OmpDimsModifier>(mods);
+ auto *lowerBound =
+ semantics::OmpGetUniqueModifier<parser::OmpLowerBound>(mods);
+ auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+ assert(!values.empty());
+
+ // With dims: one Range per value, upper bound only (a lower-bound
+ // modifier, if any, is not used). NOTE(review): dims(N) is only
+ // null-checked, so N is implied by the value count; confirm semantics.
+ if (dims) {
+ List<NumTeams::Range> v;
+ for (const auto &val : values) {
+ v.push_back(NumTeams::Range{{/*LowerBound=*/std::nullopt,
+ /*UpperBound=*/makeExpr(val, semaCtx)}});
+ }
+ return NumTeams{/*List=*/v};
+ }
+
+ // Without dims: one Range (optional lower-bound modifier, values.front()).
+ auto lb = maybeApplyToV(makeExprFn(semaCtx), lowerBound);
+ List<NumTeams::Range> v{{{/*LowerBound=*/lb,
+ /*UpperBound=*/makeExpr(values.front(), semaCtx)}}};
return NumTeams{/*List=*/v};
}
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 989e370870f33..e7a3cfcf52cd2 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -99,6 +99,10 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
vars.push_back(ops.numTeamsUpper);
+ // num_teams with dims modifier (OpenMP 6.1)
+ for (mlir::Value val : ops.numTeamsDimsValues)
+ vars.push_back(val);
+
if (ops.numThreads)
vars.push_back(ops.numThreads);
@@ -115,8 +119,8 @@ class HostEvalInfo {
assert(args.size() ==
ops.loopLowerBounds.size() + ops.loopUpperBounds.size() +
ops.loopSteps.size() + (ops.numTeamsLower ? 1 : 0) +
- (ops.numTeamsUpper ? 1 : 0) + (ops.numThreads ? 1 : 0) +
- (ops.threadLimit ? 1 : 0) &&
+ (ops.numTeamsUpper ? 1 : 0) + ops.numTeamsDimsValues.size() +
+ (ops.numThreads ? 1 : 0) + (ops.threadLimit ? 1 : 0) &&
"invalid block argument list");
int argIndex = 0;
for (size_t i = 0; i < ops.loopLowerBounds.size(); ++i)
@@ -134,6 +138,10 @@ class HostEvalInfo {
if (ops.numTeamsUpper)
ops.numTeamsUpper = args[argIndex++];
+ // num_teams with dims modifier (OpenMP 6.1)
+ for (size_t i = 0; i < ops.numTeamsDimsValues.size(); ++i)
+ ops.numTeamsDimsValues[i] = args[argIndex++];
+
if (ops.numThreads)
ops.numThreads = args[argIndex++];
@@ -185,11 +193,15 @@ class HostEvalInfo {
/// \returns whether an update was performed. If not, these clauses were not
/// evaluated in the host device.
bool apply(mlir::omp::TeamsOperands &clauseOps) {
- if (!ops.numTeamsLower && !ops.numTeamsUpper && !ops.threadLimit)
+ if (!ops.numTeamsLower && !ops.numTeamsUpper &&
+ ops.numTeamsDimsValues.empty() && !ops.threadLimit)
return false;
clauseOps.numTeamsLower = ops.numTeamsLower;
clauseOps.numTeamsUpper = ops.numTeamsUpper;
+ // num_teams dims modifier (OpenMP 6.1): per-dimension values and count.
+ clauseOps.numTeamsDimsValues = ops.numTeamsDimsValues;
+ clauseOps.numTeamsNumDims = ops.numTeamsNumDims;
clauseOps.threadLimit = ops.threadLimit;
return true;
}
diff --git a/flang/test/Lower/OpenMP/num-teams-dims.f90 b/flang/test/Lower/OpenMP/num-teams-dims.f90
new file mode 100644
index 0000000000000..fd5d5ba40c804
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-teams-dims.f90
@@ -0,0 +1,69 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_teams` clause with dims modifier (OpenMP 6.1)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_dims2
+subroutine teams_numteams_dims2()
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32)
+ !$omp teams num_teams(dims(2): 10, 4)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_dims2
+
+! CHECK-LABEL: func @_QPteams_numteams_dims3
+subroutine teams_numteams_dims3()
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+ !$omp teams num_teams(dims(3): 8, 4, 2)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_dims3
+
+! CHECK-LABEL: func @_QPteams_numteams_dims_var
+subroutine teams_numteams_dims_var(a, b, c)
+ integer, intent(in) :: a, b, c
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+ !$omp teams num_teams(dims(3): a, b, c)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_dims_var
+
+!===============================================================================
+! `num_teams` clause with lower-bound modifier (no dims modifier)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_lower_upper
+subroutine teams_numteams_lower_upper(lower, upper)
+ integer, intent(in) :: lower, upper
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(%{{.*}} : i32 to %{{.*}} : i32)
+ !$omp teams num_teams(lower: upper)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_lower_upper
+
+!===============================================================================
+! `num_teams` with dims modifier on a combined `target teams` construct
+! (exercises the host-evaluated num_teams dims values, HostEvalInfo)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPtarget_teams_numteams_dims
+subroutine target_teams_numteams_dims(a, b)
+ integer, intent(in) :: a, b
+ ! CHECK: omp.target
+ ! CHECK-SAME: host_eval(
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32)
+ !$omp target teams num_teams(dims(2): a, b)
+ call f1()
+ !$omp end target teams
+end subroutine target_teams_numteams_dims
+
More information about the llvm-branch-commits
mailing list