[flang-commits] [flang] [flang][openacc] Add support for force clause for loop collapse (PR #162534)
Susan Tan ス-ザン タン via flang-commits
flang-commits at lists.llvm.org
Wed Oct 8 16:24:28 PDT 2025
https://github.com/SusanTan updated https://github.com/llvm/llvm-project/pull/162534
From 65b02a4c45431c335021c0131264ffdba65c0524 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Tue, 7 Oct 2025 16:47:15 -0700
Subject: [PATCH 01/10] add initial implementation
---
flang/lib/Lower/Bridge.cpp | 69 +++++++++++++++++++++++++++++++++++--
flang/lib/Lower/OpenACC.cpp | 4 ---
2 files changed, 66 insertions(+), 7 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 780d56f085f69..f75f648fbdfcc 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3190,15 +3190,39 @@ class FirConverter : public Fortran::lower::AbstractConverter {
std::get_if<Fortran::parser::OpenACCCombinedConstruct>(&acc.u);
Fortran::lower::pft::Evaluation *curEval = &getEval();
+ // Determine collapse depth/force and loopCount
+ bool collapseForce = false;
+ uint64_t collapseDepth = 1;
+ uint64_t loopCount = 1;
+ auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
+ -> std::pair<bool, uint64_t> {
+ bool force = false;
+ uint64_t depth = 1;
+ for (const Fortran::parser::AccClause &clause : cl.v) {
+ if (const auto *collapseClause =
+ std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ force = std::get<bool>(arg.t);
+ const auto &intExpr =
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+ if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
+ if (auto v = Fortran::evaluate::ToInt64(*expr))
+ depth = *v;
+ }
+ break;
+ }
+ }
+ return {force, depth};
+ };
if (accLoop || accCombined) {
- uint64_t loopCount;
if (accLoop) {
const Fortran::parser::AccBeginLoopDirective &beginLoopDir =
std::get<Fortran::parser::AccBeginLoopDirective>(accLoop->t);
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3206,6 +3230,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
+ std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
}
if (curEval->lowerAsStructured()) {
@@ -3215,8 +3240,46 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
- genFIR(e);
+ // Collect prologue and tail (after-inner) statements if force
+ llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, tail;
+ if (collapseForce && loopCount > 1 && getEval().lowerAsStructured()) {
+ auto hasKids = [](Fortran::lower::pft::Evaluation *ev) -> bool {
+ return ev && ev->hasNestedEvaluations();
+ };
+ Fortran::lower::pft::Evaluation *parent = &getEval();
+ uint64_t levelsToProcess = std::min<uint64_t>(collapseDepth, loopCount);
+ for (uint64_t lvl = 0; lvl + 1 < levelsToProcess; ++lvl) {
+ if (!hasKids(parent)) break;
+ Fortran::lower::pft::Evaluation *childLoop = nullptr;
+ tail.clear();
+ auto &kids = parent->getNestedEvaluations();
+ for (auto it = kids.begin(); it != kids.end(); ++it) {
+ if (it->getIf<Fortran::parser::DoConstruct>()) {
+ childLoop = &*it;
+ for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
+ tail.push_back(&*it2);
+ break;
+ }
+ prologue.push_back(&*it);
+ }
+ if (!childLoop) break;
+ parent = childLoop;
+ }
+ }
+
+ // Prologue sink
+ for (auto *e : prologue)
+ genFIR(*e);
+
+ // Lower the loop body as usual
+ if (curEval && curEval->hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+ genFIR(e);
+ }
+
+ // Epilogue sink
+ for (auto *e : tail)
+ genFIR(*e);
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4a9e49435a907..4653f40e77948 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2406,10 +2406,6 @@ static mlir::acc::LoopOp createLoopOp(
std::get_if<Fortran::parser::AccClause::Collapse>(
&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
- const auto &force = std::get<bool>(arg.t);
- if (force)
- TODO(clauseLocation, "OpenACC collapse force modifier");
-
const auto &intExpr =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
const auto *expr = Fortran::semantics::GetExpr(intExpr);
From 64d01104612b5641c2b8b4685f6d3f82762456ad Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Tue, 7 Oct 2025 16:58:08 -0700
Subject: [PATCH 02/10] tweak
---
flang/lib/Lower/OpenACC.cpp | 22 ++++++++++++++++++----
1 file changed, 18 insertions(+), 4 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 4653f40e77948..e24e784895fe8 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2144,8 +2144,23 @@ static void processDoLoopBounds(
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
- auto *doCons = crtEval->getIf<Fortran::parser::DoConstruct>();
- assert(doCons && "expect do construct");
+ // Safely locate the next inner DoConstruct within this eval.
+ const Fortran::parser::DoConstruct *doCons = nullptr;
+ if (crtEval && crtEval->hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &child :
+ crtEval->getNestedEvaluations()) {
+ if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
+ doCons = cand;
+ // Prepare to descend for the next iteration
+ crtEval = &child;
+ break;
+ }
+ }
+ }
+ if (!doCons) {
+ // No deeper loop; stop collecting collapsed bounds.
+ break;
+ }
loopControl = &*doCons->GetLoopControl();
locs.push_back(converter.genLocation(
Fortran::parser::FindSourceLocation(*doCons)));
@@ -2172,8 +2187,7 @@ static void processDoLoopBounds(
inclusiveBounds.push_back(true);
- if (i < loopsToProcess - 1)
- crtEval = &*std::next(crtEval->getNestedEvaluations().begin());
+ // crtEval already updated when descending; no blind increment here.
}
}
}
From c7e7321e65fcf9a6138cf526a4c40d5f303f636b Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 08:49:42 -0700
Subject: [PATCH 03/10] tweak
---
flang/lib/Lower/Bridge.cpp | 16 ++++++++++++++--
1 file changed, 14 insertions(+), 2 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index f75f648fbdfcc..b406de9a739ff 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3267,14 +3267,26 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
+ // Track sunk evaluations to avoid double-lowering
+ llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+ for (auto *e : prologue) sunk.insert(e);
+ for (auto *e : tail) sunk.insert(e);
+
// Prologue sink
for (auto *e : prologue)
genFIR(*e);
- // Lower the loop body as usual
+ // Lower the loop body as usual, skipping already-sunk evals
if (curEval && curEval->hasNestedEvaluations()) {
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
+ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) {
+ if (sunk.contains(&e)) continue;
+ genFIR(e);
+ }
+ } else if (getEval().hasNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) {
+ if (sunk.contains(&e)) continue;
genFIR(e);
+ }
}
// Epilogue sink
From 48733f4d7258267ea2cd8b99bfe764f6fd559feb Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 11:27:28 -0700
Subject: [PATCH 04/10] code cleanup
---
flang/lib/Lower/Bridge.cpp | 70 +++++++++++++++++++-------------------
1 file changed, 35 insertions(+), 35 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index b406de9a739ff..32eb382e2c34f 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3240,58 +3240,58 @@ class FirConverter : public Fortran::lower::AbstractConverter {
}
}
- // Collect prologue and tail (after-inner) statements if force
- llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, tail;
- if (collapseForce && loopCount > 1 && getEval().lowerAsStructured()) {
- auto hasKids = [](Fortran::lower::pft::Evaluation *ev) -> bool {
- return ev && ev->hasNestedEvaluations();
- };
+ const bool isStructured = curEval && curEval->lowerAsStructured();
+ if (isStructured && collapseForce && collapseDepth > 1) {
+ // force: collect prologue/epilogue for the first collapseDepth nested loops
+ // and sink them into the innermost loop body at that depth
+ llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
Fortran::lower::pft::Evaluation *parent = &getEval();
- uint64_t levelsToProcess = std::min<uint64_t>(collapseDepth, loopCount);
- for (uint64_t lvl = 0; lvl + 1 < levelsToProcess; ++lvl) {
- if (!hasKids(parent)) break;
- Fortran::lower::pft::Evaluation *childLoop = nullptr;
- tail.clear();
+ Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
+ for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
+ epilogue.clear();
auto &kids = parent->getNestedEvaluations();
+ // Collect all non-loop statements before the next inner loop as prologue,
+ // then mark remaining siblings as epilogue and descend into the inner loop.
+ Fortran::lower::pft::Evaluation *childLoop = nullptr;
for (auto it = kids.begin(); it != kids.end(); ++it) {
if (it->getIf<Fortran::parser::DoConstruct>()) {
childLoop = &*it;
for (auto it2 = std::next(it); it2 != kids.end(); ++it2)
- tail.push_back(&*it2);
+ epilogue.push_back(&*it2);
break;
}
prologue.push_back(&*it);
}
- if (!childLoop) break;
+ // Semantics guarantees collapseDepth does not exceed nest depth
+ // so childLoop must be found here.
+ assert(childLoop && "Expected inner DoConstruct for collapse");
parent = childLoop;
+ innermostLoopEval = childLoop;
}
- }
- // Track sunk evaluations to avoid double-lowering
- llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
- for (auto *e : prologue) sunk.insert(e);
- for (auto *e : tail) sunk.insert(e);
+ // Track sunk evaluations (avoid double-lowering)
+ llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
+ for (auto *e : prologue) sunk.insert(e);
+ for (auto *e : epilogue) sunk.insert(e);
- // Prologue sink
- for (auto *e : prologue)
- genFIR(*e);
+ auto emit = [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+ for (auto *e : lst) genFIR(*e);
+ };
- // Lower the loop body as usual, skipping already-sunk evals
- if (curEval && curEval->hasNestedEvaluations()) {
- for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations()) {
- if (sunk.contains(&e)) continue;
- genFIR(e);
- }
- } else if (getEval().hasNestedEvaluations()) {
- for (Fortran::lower::pft::Evaluation &e : getEval().getNestedEvaluations()) {
- if (sunk.contains(&e)) continue;
+ // Sink prologue
+ emit(prologue);
+
+ // Lower innermost loop body, skipping sunk
+ for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
+ if (!sunk.contains(&e)) genFIR(e);
+
+ // Sink epilogue
+ emit(epilogue);
+ } else {
+ // Normal lowering
+ for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
genFIR(e);
- }
}
-
- // Epilogue sink
- for (auto *e : tail)
- genFIR(*e);
localSymbols.popScope();
builder->restoreInsertionPoint(insertPt);
From 32b5f71aa9b88ff8ab7681ed23f5e939572a6cca Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 11:57:49 -0700
Subject: [PATCH 05/10] add a test
---
flang/lib/Lower/Bridge.cpp | 14 +++----
.../acc-loop-collapse-force-lowering.f90 | 41 +++++++++++++++++++
2 files changed, 48 insertions(+), 7 deletions(-)
create mode 100644 flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 32eb382e2c34f..3d331cdad3d43 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3274,19 +3274,19 @@ class FirConverter : public Fortran::lower::AbstractConverter {
for (auto *e : prologue) sunk.insert(e);
for (auto *e : epilogue) sunk.insert(e);
- auto emit = [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
- for (auto *e : lst) genFIR(*e);
- };
+ auto sink =
+ [&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
+ for (auto *e : lst)
+ genFIR(*e);
+ };
- // Sink prologue
- emit(prologue);
+ sink(prologue);
// Lower innermost loop body, skipping sunk
for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
if (!sunk.contains(&e)) genFIR(e);
- // Sink epilogue
- emit(epilogue);
+ sink(epilogue);
} else {
// Normal lowering
for (Fortran::lower::pft::Evaluation &e : curEval->getNestedEvaluations())
diff --git a/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90 b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
new file mode 100644
index 0000000000000..ca932c1b159ba
--- /dev/null
+++ b/flang/test/Lower/OpenACC/acc-loop-collapse-force-lowering.f90
@@ -0,0 +1,41 @@
+! RUN: bbc -fopenacc -emit-hlfir %s -o - | FileCheck %s
+
+! Verify collapse(force:2) sinks prologue (between loops) and epilogue (after inner loop)
+! into the acc.loop region body.
+
+subroutine collapse_force_sink(n, m)
+ integer, intent(in) :: n, m
+ real, dimension(n,m) :: a
+ real, dimension(n) :: bb, cc
+ integer :: i, j
+
+ !$acc parallel loop collapse(force:2)
+ do i = 1, n
+ bb(i) = 4.2 ! prologue (between loops)
+ do j = 1, m
+ a(i,j) = a(i,j) + 2.0
+ end do
+ cc(i) = 7.3 ! epilogue (after inner loop)
+ end do
+ !$acc end parallel loop
+end subroutine
+
+! CHECK: func.func @_QPcollapse_force_sink(
+! CHECK: acc.parallel
+! Ensure outer acc.loop is combined(parallel)
+! CHECK: acc.loop combined(parallel)
+! Prologue: constant 4.2 and an assign before inner loop
+! CHECK: arith.constant 4.200000e+00
+! CHECK: hlfir.assign
+! Inner loop and its body include 2.0 add and an assign
+! CHECK: acc.loop
+! CHECK: arith.constant 2.000000e+00
+! CHECK: arith.addf
+! CHECK: hlfir.assign
+! Epilogue: constant 7.3 and an assign after inner loop
+! CHECK: arith.constant 7.300000e+00
+! CHECK: hlfir.assign
+! And the outer acc.loop has collapse = [2]
+! CHECK: } attributes {collapse = [2]
+
+
From fa52cbb69007c2ecf4d1818289fb0cb93a7b1e94 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 12:18:17 -0700
Subject: [PATCH 06/10] cleanup code
---
flang/lib/Lower/OpenACC.cpp | 20 +++++++++-----------
1 file changed, 9 insertions(+), 11 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index e24e784895fe8..c376609ee1b5b 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2145,25 +2145,23 @@ static void processDoLoopBounds(
Fortran::parser::FindSourceLocation(outerDoConstruct)));
} else {
// Safely locate the next inner DoConstruct within this eval.
- const Fortran::parser::DoConstruct *doCons = nullptr;
+ const Fortran::parser::DoConstruct *innerDo = nullptr;
if (crtEval && crtEval->hasNestedEvaluations()) {
- for (Fortran::lower::pft::Evaluation &child :
- crtEval->getNestedEvaluations()) {
- if (auto *cand = child.getIf<Fortran::parser::DoConstruct>()) {
- doCons = cand;
+ for (Fortran::lower::pft::Evaluation &child : crtEval->getNestedEvaluations()) {
+ if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
+ innerDo = stmt;
// Prepare to descend for the next iteration
crtEval = &child;
break;
}
}
}
- if (!doCons) {
- // No deeper loop; stop collecting collapsed bounds.
- break;
- }
- loopControl = &*doCons->GetLoopControl();
+ if (!innerDo)
+ break; // No deeper loop; stop collecting collapsed bounds.
+
+ loopControl = &*innerDo->GetLoopControl();
locs.push_back(converter.genLocation(
- Fortran::parser::FindSourceLocation(*doCons)));
+ Fortran::parser::FindSourceLocation(*innerDo)));
}
const Fortran::parser::LoopControl::Bounds *bounds =
From e2141ff297091d0f3da86dd70425c5f6d542f546 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 12:38:53 -0700
Subject: [PATCH 07/10] refactor code to parse in OpenACC.cpp
---
flang/include/flang/Lower/OpenACC.h | 3 +++
flang/lib/Lower/Bridge.cpp | 6 ++++--
flang/lib/Lower/OpenACC.cpp | 14 ++++++++++++++
3 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index 4622dbc8ccf64..f6ec3658eff30 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -122,6 +122,9 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
/// clause.
uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);
+/// Returns only the collapse(N) depth (defaults to 1 when absent).
+uint64_t getLoopCountForCollapse(const Fortran::parser::AccClauseList &);
+
/// Checks whether the current insertion point is inside OpenACC loop.
bool isInOpenACCLoop(fir::FirOpBuilder &);
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 3d331cdad3d43..8482bba4ecbf8 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3222,7 +3222,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
- std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
+ collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList);
+ std::tie(collapseForce, std::ignore) = parseCollapse(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3230,7 +3231,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
- std::tie(collapseForce, collapseDepth) = parseCollapse(clauseList);
+ collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList);
+ std::tie(collapseForce, std::ignore) = parseCollapse(clauseList);
}
if (curEval->lowerAsStructured()) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index c376609ee1b5b..90edc102e13a0 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4889,6 +4889,20 @@ uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
return collapseLoopCount;
}
+uint64_t Fortran::lower::getLoopCountForCollapse(
+ const Fortran::parser::AccClauseList &clauseList) {
+ for (const Fortran::parser::AccClause &clause : clauseList.v) {
+ if (const auto *collapseClause =
+ std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ const auto &collapseValue =
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
+ return *Fortran::semantics::GetIntValue(collapseValue);
+ }
+ }
+ return 1;
+}
+
/// Create an ACC loop operation for a DO construct when inside ACC compute
/// constructs This serves as a bridge between regular DO construct handling and
/// ACC loop creation
From c27d440f5510bb441aaa23ec96b4218142d53a73 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 12:54:15 -0700
Subject: [PATCH 08/10] use collapseLoopCount in calculating tile loopcount
---
flang/lib/Lower/OpenACC.cpp | 15 +++------------
1 file changed, 3 insertions(+), 12 deletions(-)
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index 90edc102e13a0..fb8c882cb30fa 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4868,25 +4868,16 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
const Fortran::parser::AccClauseList &clauseList) {
- uint64_t collapseLoopCount = 1;
+ uint64_t collapseLoopCount = getLoopCountForCollapse(clauseList);
uint64_t tileLoopCount = 1;
for (const Fortran::parser::AccClause &clause : clauseList.v) {
- if (const auto *collapseClause =
- std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
- const parser::AccCollapseArg &arg = collapseClause->v;
- const auto &collapseValue{std::get<parser::ScalarIntConstantExpr>(arg.t)};
- collapseLoopCount = *Fortran::semantics::GetIntValue(collapseValue);
- }
if (const auto *tileClause =
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
const parser::AccTileExprList &tileExprList = tileClause->v;
- const std::list<parser::AccTileExpr> &listTileExpr = tileExprList.v;
- tileLoopCount = listTileExpr.size();
+ tileLoopCount = tileExprList.v.size();
}
}
- if (tileLoopCount > collapseLoopCount)
- return tileLoopCount;
- return collapseLoopCount;
+ return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount;
}
uint64_t Fortran::lower::getLoopCountForCollapse(
From 8c11d6fb3f73c02ceb5c7c035071e88a09ce4131 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 16:06:17 -0700
Subject: [PATCH 09/10] format
---
flang/lib/Lower/Bridge.cpp | 23 ++++++++++++++---------
flang/lib/Lower/OpenACC.cpp | 3 ++-
2 files changed, 16 insertions(+), 10 deletions(-)
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index 8482bba4ecbf8..e3f59d475b24c 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3195,7 +3195,7 @@ class FirConverter : public Fortran::lower::AbstractConverter {
uint64_t collapseDepth = 1;
uint64_t loopCount = 1;
auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
- -> std::pair<bool, uint64_t> {
+ -> std::pair<bool, uint64_t> {
bool force = false;
uint64_t depth = 1;
for (const Fortran::parser::AccClause &clause : cl.v) {
@@ -3244,16 +3244,17 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const bool isStructured = curEval && curEval->lowerAsStructured();
if (isStructured && collapseForce && collapseDepth > 1) {
- // force: collect prologue/epilogue for the first collapseDepth nested loops
- // and sink them into the innermost loop body at that depth
+ // force: collect prologue/epilogue for the first collapseDepth nested
+ // loops and sink them into the innermost loop body at that depth
llvm::SmallVector<Fortran::lower::pft::Evaluation *> prologue, epilogue;
Fortran::lower::pft::Evaluation *parent = &getEval();
Fortran::lower::pft::Evaluation *innermostLoopEval = nullptr;
for (uint64_t lvl = 0; lvl + 1 < collapseDepth; ++lvl) {
epilogue.clear();
auto &kids = parent->getNestedEvaluations();
- // Collect all non-loop statements before the next inner loop as prologue,
- // then mark remaining siblings as epilogue and descend into the inner loop.
+ // Collect all non-loop statements before the next inner loop as
+ // prologue, then mark remaining siblings as epilogue and descend into
+ // the inner loop.
Fortran::lower::pft::Evaluation *childLoop = nullptr;
for (auto it = kids.begin(); it != kids.end(); ++it) {
if (it->getIf<Fortran::parser::DoConstruct>()) {
@@ -3273,8 +3274,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
// Track sunk evaluations (avoid double-lowering)
llvm::SmallPtrSet<const Fortran::lower::pft::Evaluation *, 16> sunk;
- for (auto *e : prologue) sunk.insert(e);
- for (auto *e : epilogue) sunk.insert(e);
+ for (auto *e : prologue)
+ sunk.insert(e);
+ for (auto *e : epilogue)
+ sunk.insert(e);
auto sink =
[&](llvm::SmallVector<Fortran::lower::pft::Evaluation *> &lst) {
@@ -3285,8 +3288,10 @@ class FirConverter : public Fortran::lower::AbstractConverter {
sink(prologue);
// Lower innermost loop body, skipping sunk
- for (Fortran::lower::pft::Evaluation &e : innermostLoopEval->getNestedEvaluations())
- if (!sunk.contains(&e)) genFIR(e);
+ for (Fortran::lower::pft::Evaluation &e :
+ innermostLoopEval->getNestedEvaluations())
+ if (!sunk.contains(&e))
+ genFIR(e);
sink(epilogue);
} else {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index fb8c882cb30fa..c38fd3a78c393 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -2147,7 +2147,8 @@ static void processDoLoopBounds(
// Safely locate the next inner DoConstruct within this eval.
const Fortran::parser::DoConstruct *innerDo = nullptr;
if (crtEval && crtEval->hasNestedEvaluations()) {
- for (Fortran::lower::pft::Evaluation &child : crtEval->getNestedEvaluations()) {
+ for (Fortran::lower::pft::Evaluation &child :
+ crtEval->getNestedEvaluations()) {
if (auto *stmt = child.getIf<Fortran::parser::DoConstruct>()) {
innerDo = stmt;
// Prepare to descend for the next iteration
From 7cca8e4ee98c7f8f6c2df67784e40a0ab6c55dc5 Mon Sep 17 00:00:00 2001
From: Susan Tan <zujunt at nvidia.com>
Date: Wed, 8 Oct 2025 16:24:13 -0700
Subject: [PATCH 10/10] refactor
---
flang/include/flang/Lower/OpenACC.h | 6 ++++--
flang/lib/Lower/Bridge.cpp | 28 ++++------------------------
flang/lib/Lower/OpenACC.cpp | 12 ++++++++----
3 files changed, 16 insertions(+), 30 deletions(-)
diff --git a/flang/include/flang/Lower/OpenACC.h b/flang/include/flang/Lower/OpenACC.h
index f6ec3658eff30..69f1f5be753e6 100644
--- a/flang/include/flang/Lower/OpenACC.h
+++ b/flang/include/flang/Lower/OpenACC.h
@@ -122,8 +122,10 @@ void genOpenACCTerminator(fir::FirOpBuilder &, mlir::Operation *,
/// clause.
uint64_t getLoopCountForCollapseAndTile(const Fortran::parser::AccClauseList &);
-/// Returns only the collapse(N) depth (defaults to 1 when absent).
-uint64_t getLoopCountForCollapse(const Fortran::parser::AccClauseList &);
+/// Parse collapse clause and return {size, force}. If absent, returns
+/// {1,false}.
+std::pair<uint64_t, bool>
+getCollapseSizeAndForce(const Fortran::parser::AccClauseList &);
/// Checks whether the current insertion point is inside OpenACC loop.
bool isInOpenACCLoop(fir::FirOpBuilder &);
diff --git a/flang/lib/Lower/Bridge.cpp b/flang/lib/Lower/Bridge.cpp
index e3f59d475b24c..b41ebd7d15a00 100644
--- a/flang/lib/Lower/Bridge.cpp
+++ b/flang/lib/Lower/Bridge.cpp
@@ -3194,26 +3194,6 @@ class FirConverter : public Fortran::lower::AbstractConverter {
bool collapseForce = false;
uint64_t collapseDepth = 1;
uint64_t loopCount = 1;
- auto parseCollapse = [&](const Fortran::parser::AccClauseList &cl)
- -> std::pair<bool, uint64_t> {
- bool force = false;
- uint64_t depth = 1;
- for (const Fortran::parser::AccClause &clause : cl.v) {
- if (const auto *collapseClause =
- std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
- const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
- force = std::get<bool>(arg.t);
- const auto &intExpr =
- std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
- if (const auto *expr = Fortran::semantics::GetExpr(intExpr)) {
- if (auto v = Fortran::evaluate::ToInt64(*expr))
- depth = *v;
- }
- break;
- }
- }
- return {force, depth};
- };
if (accLoop || accCombined) {
if (accLoop) {
@@ -3222,8 +3202,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginLoopDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
- collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList);
- std::tie(collapseForce, std::ignore) = parseCollapse(clauseList);
+ std::tie(collapseDepth, collapseForce) =
+ Fortran::lower::getCollapseSizeAndForce(clauseList);
} else if (accCombined) {
const Fortran::parser::AccBeginCombinedDirective &beginCombinedDir =
std::get<Fortran::parser::AccBeginCombinedDirective>(
@@ -3231,8 +3211,8 @@ class FirConverter : public Fortran::lower::AbstractConverter {
const Fortran::parser::AccClauseList &clauseList =
std::get<Fortran::parser::AccClauseList>(beginCombinedDir.t);
loopCount = Fortran::lower::getLoopCountForCollapseAndTile(clauseList);
- collapseDepth = Fortran::lower::getLoopCountForCollapse(clauseList);
- std::tie(collapseForce, std::ignore) = parseCollapse(clauseList);
+ std::tie(collapseDepth, collapseForce) =
+ Fortran::lower::getCollapseSizeAndForce(clauseList);
}
if (curEval->lowerAsStructured()) {
diff --git a/flang/lib/Lower/OpenACC.cpp b/flang/lib/Lower/OpenACC.cpp
index c38fd3a78c393..0aed144fc5123 100644
--- a/flang/lib/Lower/OpenACC.cpp
+++ b/flang/lib/Lower/OpenACC.cpp
@@ -4869,7 +4869,7 @@ void Fortran::lower::genEarlyReturnInOpenACCLoop(fir::FirOpBuilder &builder,
uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
const Fortran::parser::AccClauseList &clauseList) {
- uint64_t collapseLoopCount = getLoopCountForCollapse(clauseList);
+ uint64_t collapseLoopCount = getCollapseSizeAndForce(clauseList).first;
uint64_t tileLoopCount = 1;
for (const Fortran::parser::AccClause &clause : clauseList.v) {
if (const auto *tileClause =
@@ -4881,18 +4881,22 @@ uint64_t Fortran::lower::getLoopCountForCollapseAndTile(
return tileLoopCount > collapseLoopCount ? tileLoopCount : collapseLoopCount;
}
-uint64_t Fortran::lower::getLoopCountForCollapse(
+std::pair<uint64_t, bool> Fortran::lower::getCollapseSizeAndForce(
const Fortran::parser::AccClauseList &clauseList) {
+ uint64_t size = 1;
+ bool force = false;
for (const Fortran::parser::AccClause &clause : clauseList.v) {
if (const auto *collapseClause =
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
+ force = std::get<bool>(arg.t);
const auto &collapseValue =
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
- return *Fortran::semantics::GetIntValue(collapseValue);
+ size = *Fortran::semantics::GetIntValue(collapseValue);
+ break;
}
}
- return 1;
+ return {size, force};
}
/// Create an ACC loop operation for a DO construct when inside ACC compute
More information about the flang-commits mailing list