[llvm-branch-commits] [flang] [FLANG] Add flang to mlir lowering for num_teams (PR #175790)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Tue Jan 13 08:29:35 PST 2026
https://github.com/skc7 created https://github.com/llvm/llvm-project/pull/175790
None
>From 75cf8e2211eb9dce46ed9f6f5e57643efaddf280 Mon Sep 17 00:00:00 2001
From: skc7 <Krishna.Sankisa at amd.com>
Date: Tue, 13 Jan 2026 21:24:34 +0530
Subject: [PATCH] [FLANG] Add flang to mlir lowering for num_teams
---
flang/lib/Lower/OpenMP/ClauseProcessor.cpp | 37 ++++++++++-----
flang/lib/Lower/OpenMP/Clauses.cpp | 26 +++++++++--
flang/test/Lower/OpenMP/num-teams-dims.f90 | 52 ++++++++++++++++++++++
3 files changed, 101 insertions(+), 14 deletions(-)
create mode 100644 flang/test/Lower/OpenMP/num-teams-dims.f90
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index b923e415231d6..579ce359ed357 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -495,17 +495,34 @@ bool ClauseProcessor::processSizes(StatementContext &stmtCtx,
bool ClauseProcessor::processNumTeams(
lower::StatementContext &stmtCtx,
mlir::omp::NumTeamsClauseOps &result) const {
- // TODO Get lower and upper bounds for num_teams when parser is updated to
- // accept both.
if (auto *clause = findUniqueClause<omp::clause::NumTeams>()) {
- // The num_teams directive accepts a list of team lower/upper bounds.
- // This is an extension to support grid specification for ompx_bare.
- // Here, only expect a single element in the list.
- assert(clause->v.size() == 1);
- // auto lowerBound = std::get<std::optional<ExprTy>>(clause->v[0]->t);
- auto &upperBound = std::get<ExprTy>(clause->v[0].t);
- result.numTeamsUpper =
- fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ // The clause is lowered from a list of ranges (built in Clauses.cpp):
+ // - with the dims modifier: one range per dimension, upper bound only;
+ // - otherwise: a single range with an optional lower bound.
+ assert(!clause->v.empty());
+
+ // Only the dims modifier produces more than one range. A dims(1) clause
+ // yields a single range and is therefore lowered like a plain num_teams.
+ if (clause->v.size() > 1) {
+ // Dims modifier case: one upper bound per grid dimension.
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ result.numTeamsNumDims =
+ firOpBuilder.getI64IntegerAttr(clause->v.size());
+ for (const auto &range : clause->v) {
+ const auto &upperBound = std::get<ExprTy>(range.t);
+ result.numTeamsDimsValues.push_back(
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx)));
+ }
+ } else {
+ // Single range: optional lower bound, mandatory upper bound.
+ const auto &lowerBound = std::get<std::optional<ExprTy>>(clause->v[0].t);
+ const auto &upperBound = std::get<ExprTy>(clause->v[0].t);
+ if (lowerBound)
+ result.numTeamsLower =
+ fir::getBase(converter.genExprValue(*lowerBound, stmtCtx));
+ result.numTeamsUpper =
+ fir::getBase(converter.genExprValue(upperBound, stmtCtx));
+ }
return true;
}
return false;
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index a2716fb22a75c..a2a952c2c2ba8 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1297,10 +1297,28 @@ NumTasks make(const parser::OmpClause::NumTasks &inp,
NumTeams make(const parser::OmpClause::NumTeams &inp,
semantics::SemanticsContext &semaCtx) {
// inp.v -> parser::OmpNumTeamsClause
- auto &t1 = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
- assert(!t1.empty());
- List<NumTeams::Range> v{{{/*LowerBound=*/std::nullopt,
- /*UpperBound=*/makeExpr(t1.front(), semaCtx)}}};
+ auto &mods = semantics::OmpGetModifiers(inp.v);
+ auto *dims = semantics::OmpGetUniqueModifier<parser::OmpDimsModifier>(mods);
+ auto *lowerBound =
+ semantics::OmpGetUniqueModifier<parser::OmpLowerBound>(mods);
+ auto &values = std::get<std::list<parser::ScalarIntExpr>>(inp.v.t);
+ assert(!values.empty());
+
+ // With the dims modifier each value is the upper bound of one dimension.
+ // The dims(N) argument is not read; the dimension count is the list size.
+ if (dims) {
+ List<NumTeams::Range> v;
+ for (const auto &val : values) {
+ v.push_back(NumTeams::Range{{/*LowerBound=*/std::nullopt,
+ /*UpperBound=*/makeExpr(val, semaCtx)}});
+ }
+ return NumTeams{/*List=*/v};
+ }
+
+ // Without the dims modifier: one range, built from the first value only.
+ auto lb = maybeApplyToV(makeExprFn(semaCtx), lowerBound);
+ List<NumTeams::Range> v{{{/*LowerBound=*/lb,
+ /*UpperBound=*/makeExpr(values.front(), semaCtx)}}};
return NumTeams{/*List=*/v};
}
diff --git a/flang/test/Lower/OpenMP/num-teams-dims.f90 b/flang/test/Lower/OpenMP/num-teams-dims.f90
new file mode 100644
index 0000000000000..fd5d5ba40c804
--- /dev/null
+++ b/flang/test/Lower/OpenMP/num-teams-dims.f90
@@ -0,0 +1,52 @@
+! RUN: %flang_fc1 -emit-hlfir %openmp_flags -fopenmp-version=61 %s -o - | FileCheck %s
+
+!===============================================================================
+! `num_teams` with dims modifier (OpenMP 6.1): one upper bound per dimension
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_dims2
+subroutine teams_numteams_dims2()
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(2): %{{.*}}, %{{.*}} : i32)
+ !$omp teams num_teams(dims(2): 10, 4)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_dims2
+
+! CHECK-LABEL: func @_QPteams_numteams_dims3
+subroutine teams_numteams_dims3()
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+ !$omp teams num_teams(dims(3): 8, 4, 2)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_dims3
+
+! CHECK-LABEL: func @_QPteams_numteams_dims_var
+subroutine teams_numteams_dims_var(a, b, c)
+ integer, intent(in) :: a, b, c
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(dims(3): %{{.*}}, %{{.*}}, %{{.*}} : i32)
+ !$omp teams num_teams(dims(3): a, b, c)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_dims_var
+
+!===============================================================================
+! `num_teams` clause with a lower-bound modifier (no dims modifier)
+!===============================================================================
+
+! CHECK-LABEL: func @_QPteams_numteams_lower_upper
+subroutine teams_numteams_lower_upper(lower, upper)
+ integer, intent(in) :: lower, upper
+ ! CHECK: omp.teams
+ ! CHECK-SAME: num_teams(%{{.*}} : i32 to %{{.*}} : i32)
+ !$omp teams num_teams(lower: upper)
+ call f1()
+ ! CHECK: omp.terminator
+ !$omp end teams
+end subroutine teams_numteams_lower_upper
+
More information about the llvm-branch-commits
mailing list