[flang] [llvm] [mlir] [openmp] [MLIR][OpenMP] Add omp.fuse operation (PR #168898)
via llvm-commits
llvm-commits at lists.llvm.org
Thu Nov 20 08:21:46 PST 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-openmp
Author: Ferran Toda (NouTimbaler)
<details>
<summary>Changes</summary>
This patch is a follow-up to #161213 and adds the `omp.fuse` loop transformation operation to the OpenMP dialect. It is used when lowering the `!$omp fuse` directive in Flang.
It also adds lowering tests and end-to-end tests.
---
The patch is 124.30 KiB and has been truncated to 20.00 KiB below; the full version is available at: https://github.com/llvm/llvm-project/pull/168898.diff
41 files affected:
- (modified) flang/include/flang/Parser/openmp-utils.h (+3)
- (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+7)
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+1)
- (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+4-1)
- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+2-1)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+83-23)
- (modified) flang/lib/Lower/OpenMP/Utils.cpp (+17-11)
- (modified) flang/lib/Lower/OpenMP/Utils.h (+4-2)
- (modified) flang/lib/Parser/openmp-parsers.cpp (+1)
- (modified) flang/lib/Parser/openmp-utils.cpp (+17)
- (modified) flang/lib/Semantics/canonicalize-omp.cpp (+72-45)
- (modified) flang/lib/Semantics/check-omp-loop.cpp (+106-33)
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+5-3)
- (modified) flang/lib/Semantics/check-omp-structure.h (+2)
- (modified) flang/lib/Semantics/resolve-directives.cpp (+109-101)
- (modified) flang/lib/Semantics/rewrite-parse-tree.cpp (+24-15)
- (added) flang/test/Lower/OpenMP/fuse01.f90 (+93)
- (added) flang/test/Lower/OpenMP/fuse02.f90 (+123)
- (added) flang/test/Parser/OpenMP/fail-looprange.f90 (+11)
- (added) flang/test/Parser/OpenMP/fuse-looprange.f90 (+38)
- (added) flang/test/Parser/OpenMP/fuse01.f90 (+28)
- (added) flang/test/Parser/OpenMP/fuse02.f90 (+97)
- (added) flang/test/Parser/OpenMP/loop-transformation-construct04.f90 (+80)
- (added) flang/test/Parser/OpenMP/loop-transformation-construct05.f90 (+90)
- (added) flang/test/Semantics/OpenMP/loop-transformation-clauses01.f90 (+66)
- (modified) flang/test/Semantics/OpenMP/loop-transformation-construct01.f90 (+2-2)
- (added) flang/test/Semantics/OpenMP/loop-transformation-construct02.f90 (+93)
- (added) flang/test/Semantics/OpenMP/loop-transformation-construct03.f90 (+39)
- (added) flang/test/Semantics/OpenMP/loop-transformation-construct04.f90 (+47)
- (modified) flang/test/Semantics/OpenMP/tile02.f90 (+1-1)
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+53)
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+111)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+34)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+54)
- (added) mlir/test/Dialect/OpenMP/cli-fuse.mlir (+114)
- (added) mlir/test/Dialect/OpenMP/invalid-fuse.mlir (+100)
- (added) mlir/test/Target/LLVMIR/openmp-cli-fuse01.mlir (+100)
- (added) mlir/test/Target/LLVMIR/openmp-cli-fuse02.mlir (+140)
- (added) openmp/runtime/test/transform/fuse/do-looprange.f90 (+60)
- (added) openmp/runtime/test/transform/fuse/do.f90 (+52)
``````````diff
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 36556f8dd7f4a..7396e57144b90 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
+
const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
template <typename T>
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 01e8481e05721..609a7be700c28 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
Directive::OMPD_teams_distribute_parallel_do_simd,
Directive::OMPD_teams_distribute_simd,
Directive::OMPD_teams_loop,
+ Directive::OMPD_fuse,
Directive::OMPD_tile,
Directive::OMPD_unroll,
};
+static const OmpDirectiveSet loopTransformationSet{
+ Directive::OMPD_tile,
+ Directive::OMPD_unroll,
+ Directive::OMPD_fuse,
+};
+
static const OmpDirectiveSet nonPartialVarSet{
Directive::OMPD_allocate,
Directive::OMPD_allocators,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4a392381287d5..ab3a174c7ad69 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -279,6 +279,7 @@ bool ClauseProcessor::processCollapse(
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
+ eval.getFirstNestedEvaluation(),
clauses, loopResult, iv);
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..f2defc62dce91 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1063,7 +1063,10 @@ Link make(const parser::OmpClause::Link &inp,
LoopRange make(const parser::OmpClause::Looprange &inp,
semantics::SemanticsContext &semaCtx) {
- llvm_unreachable("Unimplemented: looprange");
+ auto &t0 = std::get<0>(inp.v.t);
+ auto &t1 = std::get<1>(inp.v.t);
+ return LoopRange{{/*First*/ makeExpr(t0, semaCtx),
+ /*Count*/ makeExpr(t1, semaCtx)}};
}
Map make(const parser::OmpClause::Map &inp,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 83c2eda0a2dc7..da9480123513f 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -347,7 +347,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
mlir::omp::LoopRelatedClauseOps result;
llvm::SmallVector<const semantics::Symbol *> iv;
collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
- clauses, result, iv);
+ eval.getFirstNestedEvaluation(), clauses, result,
+ iv);
// Update the original variable just before exiting the worksharing
// loop. Conversion as follows:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c6487349c4056..2d981f421a4ae 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,9 +1982,9 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
static void genCanonicalLoopNest(
lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
- mlir::Location loc, const ConstructQueue &queue,
- ConstructQueue::const_iterator item, size_t numLoops,
- llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+ lower::pft::Evaluation &nestedEval, mlir::Location loc,
+ const ConstructQueue &queue, ConstructQueue::const_iterator item,
+ size_t numLoops, llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
assert(loops.empty() && "Expecting empty list to fill");
assert(numLoops >= 1 && "Expecting at least one loop");
@@ -1992,7 +1992,8 @@ static void genCanonicalLoopNest(
mlir::omp::LoopRelatedClauseOps loopInfo;
llvm::SmallVector<const semantics::Symbol *, 3> ivs;
- collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+ collectLoopRelatedInfo(converter, loc, eval, nestedEval, numLoops, loopInfo,
+ ivs);
assert(ivs.size() == numLoops &&
"Expected to parse as many loop variables as there are loops");
@@ -2014,7 +2015,7 @@ static void genCanonicalLoopNest(
// Step 1: Loop prologues
// Computing the trip count must happen before entering the outermost loop
- lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+ lower::pft::Evaluation *innermostEval = &nestedEval;
for ([[maybe_unused]] auto iv : ivs) {
if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
// OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
@@ -2186,7 +2187,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
canonLoops.reserve(numLoops);
- genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+ eval.getFirstNestedEvaluation(), loc, queue, item,
numLoops, canonLoops);
assert((canonLoops.size() == numLoops) &&
"Expecting the predetermined number of loops");
@@ -2217,6 +2219,58 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
sizesClause.sizes);
}
+static void genFuseOp(Fortran::lower::AbstractConverter &converter,
+ Fortran::lower::SymMap &symTable,
+ lower::StatementContext &stmtCtx,
+ Fortran::semantics::SemanticsContext &semaCtx,
+ Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+ int32_t first = 0;
+ int32_t count = 0;
+ auto iter = llvm::find_if(item->clauses, [](const Clause &clause) {
+ return clause.id == llvm::omp::Clause::OMPC_looprange;
+ });
+ if (iter != item->clauses.end()) {
+ const auto &looprange = std::get<clause::LoopRange>(iter->u);
+ first = evaluate::ToInt64(std::get<0>(looprange.t)).value();
+ count = evaluate::ToInt64(std::get<1>(looprange.t)).value();
+ }
+
+ llvm::SmallVector<mlir::Value> applyees;
+ for (auto &child : eval.getNestedEvaluations()) {
+ // Skip OmpEndLoopDirective
+ if (&child == &eval.getLastNestedEvaluation())
+ break;
+
+ // Emit the associated loop
+ llvm::SmallVector<mlir::omp::CanonicalLoopOp> canonLoops;
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval, child, loc, queue,
+ item, 1, canonLoops);
+
+ auto cli = llvm::getSingleElement(canonLoops).getCli();
+ applyees.push_back(cli);
+ }
+ // One generated loop + one for each loop not inside the specified looprange
+ // if present
+ llvm::SmallVector<mlir::Value> generatees;
+ int64_t numGeneratees = count == 0 ? 1 : applyees.size() - count + 1;
+ for (int i = 0; i < numGeneratees; i++) {
+ auto fusedCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc);
+ generatees.push_back(fusedCLI);
+ }
+ auto op = mlir::omp::FuseOp::create(firOpBuilder, loc, generatees, applyees);
+
+ if (count != 0) {
+ mlir::IntegerAttr firstAttr = firOpBuilder.getI32IntegerAttr(first);
+ mlir::IntegerAttr countAttr = firOpBuilder.getI32IntegerAttr(count);
+ op->setAttr("first", firstAttr);
+ op->setAttr("count", countAttr);
+ }
+}
+
static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
Fortran::lower::SymMap &symTable,
lower::StatementContext &stmtCtx,
@@ -2233,7 +2287,8 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
// Emit the associated loop
llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
- genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+ genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+ eval.getFirstNestedEvaluation(), loc, queue, item, 1,
canonLoops);
llvm::SmallVector<mlir::Value, 1> applyees;
@@ -3507,6 +3562,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
case llvm::omp::Directive::OMPD_tile:
genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
+ case llvm::omp::Directive::OMPD_fuse:
+ genFuseOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
+ break;
case llvm::omp::Directive::OMPD_unroll:
genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
break;
@@ -3962,22 +4020,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
mlir::Location currentLocation = converter.genLocation(beginSpec.source);
- if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
- loopConstruct.GetNestedConstruct()) {
- llvm::omp::Directive nestedDirective =
- parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
- switch (nestedDirective) {
- case llvm::omp::Directive::OMPD_tile:
- // Skip OMPD_tile since the tile sizes will be retrieved when
- // generating the omp.loop_nest op.
- break;
- default: {
- unsigned version = semaCtx.langOptions().OpenMPVersion;
- TODO(currentLocation,
- "Applying a loop-associated on the loop generated by the " +
- llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
- " construct");
- }
+ for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
+ if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+ parser::omp::GetOmpLoop(construct)) {
+ llvm::omp::Directive nestedDirective =
+ parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+ switch (nestedDirective) {
+ case llvm::omp::Directive::OMPD_tile:
+ // Skip OMPD_tile since the tile sizes will be retrieved when
+ // generating the omp.loop_nest op.
+ break;
+ default: {
+ unsigned version = semaCtx.langOptions().OpenMPVersion;
+ TODO(currentLocation,
+ "Applying a loop-associated on the loop generated by the " +
+ llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+ " construct");
+ }
+ }
}
}
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 7d7a4869ab3a6..913e4d1e69500 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -812,13 +812,14 @@ void collectTileSizesFromOpenMPConstruct(
int64_t collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
- lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+ lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+ const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
int64_t numCollapse = 1;
// Collect the loops to collapse.
- lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+ lower::pft::Evaluation *doConstructEval = &nestedEval;
if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
}
@@ -830,21 +831,21 @@ int64_t collectLoopRelatedInfo(
numCollapse = collapseValue;
}
- collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
- iv);
+ collectLoopRelatedInfo(converter, currentLocation, eval, nestedEval,
+ numCollapse, result, iv);
return numCollapse;
}
void collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
- lower::pft::Evaluation &eval, int64_t numCollapse,
- mlir::omp::LoopRelatedClauseOps &result,
+ lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+ int64_t numCollapse, mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
// Collect the loops to collapse.
- lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+ lower::pft::Evaluation *doConstructEval = &nestedEval;
if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
}
@@ -852,10 +853,15 @@ void collectLoopRelatedInfo(
// Collect sizes from tile directive if present.
std::int64_t sizesLengthValue = 0l;
if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
- processTileSizesFromOpenMPConstruct(
- ompCons, [&](const parser::OmpClause::Sizes *tclause) {
- sizesLengthValue = tclause->v.size();
- });
+ if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
+ const parser::OmpDirectiveSpecification &beginSpec{ompLoop->BeginDir()};
+ if (beginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
+ processTileSizesFromOpenMPConstruct(
+ ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+ sizesLengthValue = tclause->v.size();
+ });
+ }
+ }
}
std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 2960b663b08b2..886a5c1835f7e 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -169,13 +169,15 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
int64_t collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
- lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+ lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+ const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
void collectLoopRelatedInfo(
lower::AbstractConverter &converter, mlir::Location currentLocation,
- lower::pft::Evaluation &eval, std::int64_t collapseValue,
+ lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+ std::int64_t collapseValue,
// const omp::List<omp::Clause> &clauses,
mlir::omp::LoopRelatedClauseOps &result,
llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index e2da60ed19de8..231eea8841d4b 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
unsigned(Directive::OMPD_teams_distribute_simd),
unsigned(Directive::OMPD_teams_loop),
+ unsigned(Directive::OMPD_fuse),
unsigned(Directive::OMPD_tile),
unsigned(Directive::OMPD_unroll),
};
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 2424828293c73..dfe8dbdd5ac9e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
return nullptr;
}
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
+ if (auto *construct{GetOmp(x)}) {
+ if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
+ return omp;
+ }
+ }
+ return nullptr;
+}
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
+ if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
+ if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
+ return &z->value();
+ }
+ }
+ return nullptr;
+}
+
const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
// Clauses with OmpObjectList as its data member
using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index 0cec1969e0978..f7c53d6d8f4c4 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -9,6 +9,7 @@
#include "canonicalize-omp.h"
#include "flang/Parser/parse-tree-visitor.h"
#include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-directive-sets.h"
#include "flang/Semantics/semantics.h"
// After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
"A DO loop must follow the %s directive"_err_en_US,
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
- auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
- parser::Messages &messages) {
+ auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
+ parser::Messages &messages) {
messages.Say(dirName.source,
- "If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
+ "If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
parser::ToUpperCaseLetters(dirName.source.ToString()));
};
+ auto missingEndFuse = [](auto &dir, auto &messages) {
+ messages.Say(dir.source,
+ "The %s construct requires the END FUSE directive"_err_en_US,
+ parser::ToUpperCaseLetters(dir.source.ToString()));
+ };
+
+ bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
auto &body{std::get<parser::Block>(x.t)};
nextIt = it;
- while (++nextIt != block.end()) {
+ nextIt++;
+ while (nextIt != block.end()) {
// Ignore compiler directives.
- if (GetConstructIf<parser::CompilerDirective>(*nextIt))
+ if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
+ nextIt++;
continue;
+ }
if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
if (doCons->GetLoopControl()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
if (nextIt != block.end()) {
if (auto *endDir{
GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
- std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
- std::move(*endDir);
- nextIt = block.erase(nextIt);
+ auto &endDirName = endDir->DirName();
+ if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
+ std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
+ std::move(*endDir);
+ nextIt = block.erase(nextIt);
+ }
}
}
} else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
}
} else if (auto *ompLoopCons{
GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
- // We should allow UNROLL and TILE constructs to be inserted between an
- // OpenMP Loop Construct and the DO loop itself
+ // We should allow loop transformation constructs to be inserted between
+ ...
[truncated]
``````````
</details>
https://github.com/llvm/llvm-project/pull/168898
More information about the llvm-commits mailing list