[flang] [llvm] [mlir] [openmp] [MLIR][OpenMP] Add omp.fuse operation (PR #168898)

via llvm-commits llvm-commits at lists.llvm.org
Thu Nov 20 08:21:45 PST 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-llvm

Author: Ferran Toda (NouTimbaler)

<details>
<summary>Changes</summary>

This patch is a follow-up to #161213 and adds the omp.fuse loop transformation to the OpenMP dialect. It is used to lower the `!$omp fuse` directive in Flang.
Lowering and end-to-end tests are also added.

---

Patch is 124.30 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/168898.diff


41 Files Affected:

- (modified) flang/include/flang/Parser/openmp-utils.h (+3) 
- (modified) flang/include/flang/Semantics/openmp-directive-sets.h (+7) 
- (modified) flang/lib/Lower/OpenMP/ClauseProcessor.cpp (+1) 
- (modified) flang/lib/Lower/OpenMP/Clauses.cpp (+4-1) 
- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+2-1) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+83-23) 
- (modified) flang/lib/Lower/OpenMP/Utils.cpp (+17-11) 
- (modified) flang/lib/Lower/OpenMP/Utils.h (+4-2) 
- (modified) flang/lib/Parser/openmp-parsers.cpp (+1) 
- (modified) flang/lib/Parser/openmp-utils.cpp (+17) 
- (modified) flang/lib/Semantics/canonicalize-omp.cpp (+72-45) 
- (modified) flang/lib/Semantics/check-omp-loop.cpp (+106-33) 
- (modified) flang/lib/Semantics/check-omp-structure.cpp (+5-3) 
- (modified) flang/lib/Semantics/check-omp-structure.h (+2) 
- (modified) flang/lib/Semantics/resolve-directives.cpp (+109-101) 
- (modified) flang/lib/Semantics/rewrite-parse-tree.cpp (+24-15) 
- (added) flang/test/Lower/OpenMP/fuse01.f90 (+93) 
- (added) flang/test/Lower/OpenMP/fuse02.f90 (+123) 
- (added) flang/test/Parser/OpenMP/fail-looprange.f90 (+11) 
- (added) flang/test/Parser/OpenMP/fuse-looprange.f90 (+38) 
- (added) flang/test/Parser/OpenMP/fuse01.f90 (+28) 
- (added) flang/test/Parser/OpenMP/fuse02.f90 (+97) 
- (added) flang/test/Parser/OpenMP/loop-transformation-construct04.f90 (+80) 
- (added) flang/test/Parser/OpenMP/loop-transformation-construct05.f90 (+90) 
- (added) flang/test/Semantics/OpenMP/loop-transformation-clauses01.f90 (+66) 
- (modified) flang/test/Semantics/OpenMP/loop-transformation-construct01.f90 (+2-2) 
- (added) flang/test/Semantics/OpenMP/loop-transformation-construct02.f90 (+93) 
- (added) flang/test/Semantics/OpenMP/loop-transformation-construct03.f90 (+39) 
- (added) flang/test/Semantics/OpenMP/loop-transformation-construct04.f90 (+47) 
- (modified) flang/test/Semantics/OpenMP/tile02.f90 (+1-1) 
- (modified) llvm/include/llvm/Frontend/OpenMP/OMPIRBuilder.h (+53) 
- (modified) llvm/lib/Frontend/OpenMP/OMPIRBuilder.cpp (+111) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+34) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+68) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+54) 
- (added) mlir/test/Dialect/OpenMP/cli-fuse.mlir (+114) 
- (added) mlir/test/Dialect/OpenMP/invalid-fuse.mlir (+100) 
- (added) mlir/test/Target/LLVMIR/openmp-cli-fuse01.mlir (+100) 
- (added) mlir/test/Target/LLVMIR/openmp-cli-fuse02.mlir (+140) 
- (added) openmp/runtime/test/transform/fuse/do-looprange.f90 (+60) 
- (added) openmp/runtime/test/transform/fuse/do.f90 (+52) 


``````````diff
diff --git a/flang/include/flang/Parser/openmp-utils.h b/flang/include/flang/Parser/openmp-utils.h
index 36556f8dd7f4a..7396e57144b90 100644
--- a/flang/include/flang/Parser/openmp-utils.h
+++ b/flang/include/flang/Parser/openmp-utils.h
@@ -123,6 +123,9 @@ template <typename T> OmpDirectiveName GetOmpDirectiveName(const T &x) {
 const OpenMPDeclarativeConstruct *GetOmp(const DeclarationConstruct &x);
 const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x);
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x);
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x);
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause);
 
 template <typename T>
diff --git a/flang/include/flang/Semantics/openmp-directive-sets.h b/flang/include/flang/Semantics/openmp-directive-sets.h
index 01e8481e05721..609a7be700c28 100644
--- a/flang/include/flang/Semantics/openmp-directive-sets.h
+++ b/flang/include/flang/Semantics/openmp-directive-sets.h
@@ -275,10 +275,17 @@ static const OmpDirectiveSet loopConstructSet{
     Directive::OMPD_teams_distribute_parallel_do_simd,
     Directive::OMPD_teams_distribute_simd,
     Directive::OMPD_teams_loop,
+    Directive::OMPD_fuse,
     Directive::OMPD_tile,
     Directive::OMPD_unroll,
 };
 
+static const OmpDirectiveSet loopTransformationSet{
+    Directive::OMPD_tile,
+    Directive::OMPD_unroll,
+    Directive::OMPD_fuse,
+};
+
 static const OmpDirectiveSet nonPartialVarSet{
     Directive::OMPD_allocate,
     Directive::OMPD_allocators,
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index 4a392381287d5..ab3a174c7ad69 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -279,6 +279,7 @@ bool ClauseProcessor::processCollapse(
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) const {
 
   int64_t numCollapse = collectLoopRelatedInfo(converter, currentLocation, eval,
+                                               eval.getFirstNestedEvaluation(),
                                                clauses, loopResult, iv);
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   collapseResult.collapseNumLoops = firOpBuilder.getI64IntegerAttr(numCollapse);
diff --git a/flang/lib/Lower/OpenMP/Clauses.cpp b/flang/lib/Lower/OpenMP/Clauses.cpp
index b1a3c3d3c5439..f2defc62dce91 100644
--- a/flang/lib/Lower/OpenMP/Clauses.cpp
+++ b/flang/lib/Lower/OpenMP/Clauses.cpp
@@ -1063,7 +1063,10 @@ Link make(const parser::OmpClause::Link &inp,
 
 LoopRange make(const parser::OmpClause::Looprange &inp,
                semantics::SemanticsContext &semaCtx) {
-  llvm_unreachable("Unimplemented: looprange");
+  auto &t0 = std::get<0>(inp.v.t);
+  auto &t1 = std::get<1>(inp.v.t);
+  return LoopRange{{/*First*/ makeExpr(t0, semaCtx),
+                    /*Count*/ makeExpr(t1, semaCtx)}};
 }
 
 Map make(const parser::OmpClause::Map &inp,
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index 83c2eda0a2dc7..da9480123513f 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -347,7 +347,8 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
     mlir::omp::LoopRelatedClauseOps result;
     llvm::SmallVector<const semantics::Symbol *> iv;
     collectLoopRelatedInfo(converter, converter.getCurrentLocation(), eval,
-                           clauses, result, iv);
+                           eval.getFirstNestedEvaluation(), clauses, result,
+                           iv);
 
     // Update the original variable just before exiting the worksharing
     // loop. Conversion as follows:
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index c6487349c4056..2d981f421a4ae 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1982,9 +1982,9 @@ genLoopOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
 static void genCanonicalLoopNest(
     lower::AbstractConverter &converter, lower::SymMap &symTable,
     semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
-    mlir::Location loc, const ConstructQueue &queue,
-    ConstructQueue::const_iterator item, size_t numLoops,
-    llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
+    lower::pft::Evaluation &nestedEval, mlir::Location loc,
+    const ConstructQueue &queue, ConstructQueue::const_iterator item,
+    size_t numLoops, llvm::SmallVectorImpl<mlir::omp::CanonicalLoopOp> &loops) {
   assert(loops.empty() && "Expecting empty list to fill");
   assert(numLoops >= 1 && "Expecting at least one loop");
 
@@ -1992,7 +1992,8 @@ static void genCanonicalLoopNest(
 
   mlir::omp::LoopRelatedClauseOps loopInfo;
   llvm::SmallVector<const semantics::Symbol *, 3> ivs;
-  collectLoopRelatedInfo(converter, loc, eval, numLoops, loopInfo, ivs);
+  collectLoopRelatedInfo(converter, loc, eval, nestedEval, numLoops, loopInfo,
+                         ivs);
   assert(ivs.size() == numLoops &&
          "Expected to parse as many loop variables as there are loops");
 
@@ -2014,7 +2015,7 @@ static void genCanonicalLoopNest(
 
   // Step 1: Loop prologues
   // Computing the trip count must happen before entering the outermost loop
-  lower::pft::Evaluation *innermostEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *innermostEval = &nestedEval;
   for ([[maybe_unused]] auto iv : ivs) {
     if (innermostEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
       // OpenMP specifies DO CONCURRENT only with the `!omp loop` construct.
@@ -2186,7 +2187,8 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 3> canonLoops;
   canonLoops.reserve(numLoops);
 
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item,
                        numLoops, canonLoops);
   assert((canonLoops.size() == numLoops) &&
          "Expecting the predetermined number of loops");
@@ -2217,6 +2219,58 @@ static void genTileOp(Fortran::lower::AbstractConverter &converter,
                             sizesClause.sizes);
 }
 
+static void genFuseOp(Fortran::lower::AbstractConverter &converter,
+                      Fortran::lower::SymMap &symTable,
+                      lower::StatementContext &stmtCtx,
+                      Fortran::semantics::SemanticsContext &semaCtx,
+                      Fortran::lower::pft::Evaluation &eval, mlir::Location loc,
+                      const ConstructQueue &queue,
+                      ConstructQueue::const_iterator item) {
+  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+
+  int32_t first = 0;
+  int32_t count = 0;
+  auto iter = llvm::find_if(item->clauses, [](const Clause &clause) {
+    return clause.id == llvm::omp::Clause::OMPC_looprange;
+  });
+  if (iter != item->clauses.end()) {
+    const auto &looprange = std::get<clause::LoopRange>(iter->u);
+    first = evaluate::ToInt64(std::get<0>(looprange.t)).value();
+    count = evaluate::ToInt64(std::get<1>(looprange.t)).value();
+  }
+
+  llvm::SmallVector<mlir::Value> applyees;
+  for (auto &child : eval.getNestedEvaluations()) {
+    // Skip OmpEndLoopDirective
+    if (&child == &eval.getLastNestedEvaluation())
+      break;
+
+    // Emit the associated loop
+    llvm::SmallVector<mlir::omp::CanonicalLoopOp> canonLoops;
+    genCanonicalLoopNest(converter, symTable, semaCtx, eval, child, loc, queue,
+                         item, 1, canonLoops);
+
+    auto cli = llvm::getSingleElement(canonLoops).getCli();
+    applyees.push_back(cli);
+  }
+  // One generated loop + one for each loop not inside the specified looprange
+  // if present
+  llvm::SmallVector<mlir::Value> generatees;
+  int64_t numGeneratees = count == 0 ? 1 : applyees.size() - count + 1;
+  for (int i = 0; i < numGeneratees; i++) {
+    auto fusedCLI = mlir::omp::NewCliOp::create(firOpBuilder, loc);
+    generatees.push_back(fusedCLI);
+  }
+  auto op = mlir::omp::FuseOp::create(firOpBuilder, loc, generatees, applyees);
+
+  if (count != 0) {
+    mlir::IntegerAttr firstAttr = firOpBuilder.getI32IntegerAttr(first);
+    mlir::IntegerAttr countAttr = firOpBuilder.getI32IntegerAttr(count);
+    op->setAttr("first", firstAttr);
+    op->setAttr("count", countAttr);
+  }
+}
+
 static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
                         Fortran::lower::SymMap &symTable,
                         lower::StatementContext &stmtCtx,
@@ -2233,7 +2287,8 @@ static void genUnrollOp(Fortran::lower::AbstractConverter &converter,
 
   // Emit the associated loop
   llvm::SmallVector<mlir::omp::CanonicalLoopOp, 1> canonLoops;
-  genCanonicalLoopNest(converter, symTable, semaCtx, eval, loc, queue, item, 1,
+  genCanonicalLoopNest(converter, symTable, semaCtx, eval,
+                       eval.getFirstNestedEvaluation(), loc, queue, item, 1,
                        canonLoops);
 
   llvm::SmallVector<mlir::Value, 1> applyees;
@@ -3507,6 +3562,9 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
   case llvm::omp::Directive::OMPD_tile:
     genTileOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
+  case llvm::omp::Directive::OMPD_fuse:
+    genFuseOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
+    break;
   case llvm::omp::Directive::OMPD_unroll:
     genUnrollOp(converter, symTable, stmtCtx, semaCtx, eval, loc, queue, item);
     break;
@@ -3962,22 +4020,24 @@ static void genOMP(lower::AbstractConverter &converter, lower::SymMap &symTable,
 
   mlir::Location currentLocation = converter.genLocation(beginSpec.source);
 
-  if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
-          loopConstruct.GetNestedConstruct()) {
-    llvm::omp::Directive nestedDirective =
-        parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
-    switch (nestedDirective) {
-    case llvm::omp::Directive::OMPD_tile:
-      // Skip OMPD_tile since the tile sizes will be retrieved when
-      // generating the omp.loop_nest op.
-      break;
-    default: {
-      unsigned version = semaCtx.langOptions().OpenMPVersion;
-      TODO(currentLocation,
-           "Applying a loop-associated on the loop generated by the " +
-               llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
-               " construct");
-    }
+  for (auto &construct : std::get<parser::Block>(loopConstruct.t)) {
+    if (const parser::OpenMPLoopConstruct *ompNestedLoopCons =
+            parser::omp::GetOmpLoop(construct)) {
+      llvm::omp::Directive nestedDirective =
+          parser::omp::GetOmpDirectiveName(*ompNestedLoopCons).v;
+      switch (nestedDirective) {
+      case llvm::omp::Directive::OMPD_tile:
+        // Skip OMPD_tile since the tile sizes will be retrieved when
+        // generating the omp.loop_nest op.
+        break;
+      default: {
+        unsigned version = semaCtx.langOptions().OpenMPVersion;
+        TODO(currentLocation,
+             "Applying a loop-associated on the loop generated by the " +
+                 llvm::omp::getOpenMPDirectiveName(nestedDirective, version) +
+                 " construct");
+      }
+      }
     }
   }
 
diff --git a/flang/lib/Lower/OpenMP/Utils.cpp b/flang/lib/Lower/OpenMP/Utils.cpp
index 7d7a4869ab3a6..913e4d1e69500 100644
--- a/flang/lib/Lower/OpenMP/Utils.cpp
+++ b/flang/lib/Lower/OpenMP/Utils.cpp
@@ -812,13 +812,14 @@ void collectTileSizesFromOpenMPConstruct(
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
   int64_t numCollapse = 1;
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -830,21 +831,21 @@ int64_t collectLoopRelatedInfo(
     numCollapse = collapseValue;
   }
 
-  collectLoopRelatedInfo(converter, currentLocation, eval, numCollapse, result,
-                         iv);
+  collectLoopRelatedInfo(converter, currentLocation, eval, nestedEval,
+                         numCollapse, result, iv);
   return numCollapse;
 }
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, int64_t numCollapse,
-    mlir::omp::LoopRelatedClauseOps &result,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    int64_t numCollapse, mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv) {
 
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
   // Collect the loops to collapse.
-  lower::pft::Evaluation *doConstructEval = &eval.getFirstNestedEvaluation();
+  lower::pft::Evaluation *doConstructEval = &nestedEval;
   if (doConstructEval->getIf<parser::DoConstruct>()->IsDoConcurrent()) {
     TODO(currentLocation, "Do Concurrent in Worksharing loop construct");
   }
@@ -852,10 +853,15 @@ void collectLoopRelatedInfo(
   // Collect sizes from tile directive if present.
   std::int64_t sizesLengthValue = 0l;
   if (auto *ompCons{eval.getIf<parser::OpenMPConstruct>()}) {
-    processTileSizesFromOpenMPConstruct(
-        ompCons, [&](const parser::OmpClause::Sizes *tclause) {
-          sizesLengthValue = tclause->v.size();
-        });
+    if (auto *ompLoop{std::get_if<parser::OpenMPLoopConstruct>(&ompCons->u)}) {
+      const parser::OmpDirectiveSpecification &beginSpec{ompLoop->BeginDir()};
+      if (beginSpec.DirId() == llvm::omp::Directive::OMPD_tile) {
+        processTileSizesFromOpenMPConstruct(
+            ompCons, [&](const parser::OmpClause::Sizes *tclause) {
+              sizesLengthValue = tclause->v.size();
+            });
+      }
+    }
   }
 
   std::int64_t collapseValue = std::max(numCollapse, sizesLengthValue);
diff --git a/flang/lib/Lower/OpenMP/Utils.h b/flang/lib/Lower/OpenMP/Utils.h
index 2960b663b08b2..886a5c1835f7e 100644
--- a/flang/lib/Lower/OpenMP/Utils.h
+++ b/flang/lib/Lower/OpenMP/Utils.h
@@ -169,13 +169,15 @@ void lastprivateModifierNotSupported(const omp::clause::Lastprivate &lastp,
 
 int64_t collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, const omp::List<omp::Clause> &clauses,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
 
 void collectLoopRelatedInfo(
     lower::AbstractConverter &converter, mlir::Location currentLocation,
-    lower::pft::Evaluation &eval, std::int64_t collapseValue,
+    lower::pft::Evaluation &eval, lower::pft::Evaluation &nestedEval,
+    std::int64_t collapseValue,
     // const omp::List<omp::Clause> &clauses,
     mlir::omp::LoopRelatedClauseOps &result,
     llvm::SmallVectorImpl<const semantics::Symbol *> &iv);
diff --git a/flang/lib/Parser/openmp-parsers.cpp b/flang/lib/Parser/openmp-parsers.cpp
index e2da60ed19de8..231eea8841d4b 100644
--- a/flang/lib/Parser/openmp-parsers.cpp
+++ b/flang/lib/Parser/openmp-parsers.cpp
@@ -2260,6 +2260,7 @@ static constexpr DirectiveSet GetLoopDirectives() {
       unsigned(Directive::OMPD_teams_distribute_parallel_do_simd),
       unsigned(Directive::OMPD_teams_distribute_simd),
       unsigned(Directive::OMPD_teams_loop),
+      unsigned(Directive::OMPD_fuse),
       unsigned(Directive::OMPD_tile),
       unsigned(Directive::OMPD_unroll),
   };
diff --git a/flang/lib/Parser/openmp-utils.cpp b/flang/lib/Parser/openmp-utils.cpp
index 2424828293c73..dfe8dbdd5ac9e 100644
--- a/flang/lib/Parser/openmp-utils.cpp
+++ b/flang/lib/Parser/openmp-utils.cpp
@@ -41,6 +41,23 @@ const OpenMPConstruct *GetOmp(const ExecutionPartConstruct &x) {
   return nullptr;
 }
 
+const OpenMPLoopConstruct *GetOmpLoop(const ExecutionPartConstruct &x) {
+  if (auto *construct{GetOmp(x)}) {
+    if (auto *omp{std::get_if<OpenMPLoopConstruct>(&construct->u)}) {
+      return omp;
+    }
+  }
+  return nullptr;
+}
+const DoConstruct *GetDoConstruct(const ExecutionPartConstruct &x) {
+  if (auto *y{std::get_if<ExecutableConstruct>(&x.u)}) {
+    if (auto *z{std::get_if<common::Indirection<DoConstruct>>(&y->u)}) {
+      return &z->value();
+    }
+  }
+  return nullptr;
+}
+
 const OmpObjectList *GetOmpObjectList(const OmpClause &clause) {
   // Clauses with OmpObjectList as its data member
   using MemberObjectListClauses = std::tuple<OmpClause::Copyin,
diff --git a/flang/lib/Semantics/canonicalize-omp.cpp b/flang/lib/Semantics/canonicalize-omp.cpp
index 0cec1969e0978..f7c53d6d8f4c4 100644
--- a/flang/lib/Semantics/canonicalize-omp.cpp
+++ b/flang/lib/Semantics/canonicalize-omp.cpp
@@ -9,6 +9,7 @@
 #include "canonicalize-omp.h"
 #include "flang/Parser/parse-tree-visitor.h"
 #include "flang/Parser/parse-tree.h"
+#include "flang/Semantics/openmp-directive-sets.h"
 #include "flang/Semantics/semantics.h"
 
 // After Loop Canonicalization, rewrite OpenMP parse tree to make OpenMP
@@ -136,20 +137,30 @@ class CanonicalizationOfOmp {
           "A DO loop must follow the %s directive"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
-    auto tileUnrollError = [](const parser::OmpDirectiveName &dirName,
-                               parser::Messages &messages) {
+    auto transformUnrollError = [](const parser::OmpDirectiveName &dirName,
+                                    parser::Messages &messages) {
       messages.Say(dirName.source,
-          "If a loop construct has been fully unrolled, it cannot then be tiled"_err_en_US,
+          "If a loop construct has been fully unrolled, it cannot then be further transformed"_err_en_US,
           parser::ToUpperCaseLetters(dirName.source.ToString()));
     };
+    auto missingEndFuse = [](auto &dir, auto &messages) {
+      messages.Say(dir.source,
+          "The %s construct requires the END FUSE directive"_err_en_US,
+          parser::ToUpperCaseLetters(dir.source.ToString()));
+    };
+
+    bool endFuseNeeded = beginName.v == llvm::omp::Directive::OMPD_fuse;
 
     auto &body{std::get<parser::Block>(x.t)};
 
     nextIt = it;
-    while (++nextIt != block.end()) {
+    nextIt++;
+    while (nextIt != block.end()) {
       // Ignore compiler directives.
-      if (GetConstructIf<parser::CompilerDirective>(*nextIt))
+      if (GetConstructIf<parser::CompilerDirective>(*nextIt)) {
+        nextIt++;
         continue;
+      }
 
       if (auto *doCons{GetConstructIf<parser::DoConstruct>(*nextIt)}) {
         if (doCons->GetLoopControl()) {
@@ -160,9 +171,12 @@ class CanonicalizationOfOmp {
           if (nextIt != block.end()) {
             if (auto *endDir{
                     GetConstructIf<parser::OmpEndLoopDirective>(*nextIt)}) {
-              std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
-                  std::move(*endDir);
-              nextIt = block.erase(nextIt);
+              auto &endDirName = endDir->DirName();
+              if (endDirName.v != llvm::omp::Directive::OMPD_fuse) {
+                std::get<std::optional<parser::OmpEndLoopDirective>>(x.t) =
+                    std::move(*endDir);
+                nextIt = block.erase(nextIt);
+              }
             }
           }
         } else {
@@ -172,50 +186,45 @@ class CanonicalizationOfOmp {
         }
       } else if (auto *ompLoopCons{
                      GetOmpIf<parser::OpenMPLoopConstruct>(*nextIt)}) {
-        // We should allow UNROLL and TILE constructs to be inserted between an
-        // OpenMP Loop Construct and the DO loop itself
+        // We should allow loop transformation constructs to be inserted between
+   ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/168898


More information about the llvm-commits mailing list