[flang-commits] [flang] [flang][OpenMP] Rewrite `omp.loop` to semantically equivalent ops (PR #115443)

Kareem Ergawy via flang-commits flang-commits at lists.llvm.org
Thu Nov 21 20:41:14 PST 2024


https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/115443

From a459f4c068e7fb03526b5e83479783abb6d8a15e Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 7 Nov 2024 04:26:57 -0600
Subject: [PATCH] [flang][OpenMP] Rewrite `omp.loop` to semantically equivalent
 ops

Introduces a new conversion pass that rewrites `omp.loop` ops to their
semantically equivalent op nests based on the surrounding/binding
context of the `loop` op. Not all forms of `omp.loop` are supported yet.
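
For the `target teams loop` case handled so far, the rewrite is roughly the
following (a simplified sketch; the added tests show the exact expected IR):

  // Before:
  omp.target ... {
    omp.teams {
      omp.loop private(...) {
        omp.loop_nest ... { ... }
      }
    }
  }

  // After:
  omp.target ... {
    omp.teams {
      omp.parallel private(...) {
        omp.distribute {
          omp.wsloop {
            omp.loop_nest ... { ... }
          }
        }
      }
    }
  }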
---
 flang/include/flang/Common/OpenMP-utils.h     |  68 +++++++
 .../include/flang/Optimizer/OpenMP/Passes.td  |  21 +++
 flang/lib/Common/CMakeLists.txt               |   4 +
 flang/lib/Common/OpenMP-utils.cpp             |  47 +++++
 flang/lib/Lower/OpenMP/OpenMP.cpp             | 107 ++---------
 flang/lib/Optimizer/OpenMP/CMakeLists.txt     |   2 +
 .../OpenMP/GenericLoopConversion.cpp          | 168 ++++++++++++++++++
 flang/lib/Optimizer/Passes/Pipelines.cpp      |   1 +
 .../Lower/OpenMP/generic-loop-rewriting.f90   |  41 +++++
 .../Transforms/generic-loop-rewriting.mlir    |  57 ++++++
 10 files changed, 419 insertions(+), 97 deletions(-)
 create mode 100644 flang/include/flang/Common/OpenMP-utils.h
 create mode 100644 flang/lib/Common/OpenMP-utils.cpp
 create mode 100644 flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
 create mode 100644 flang/test/Lower/OpenMP/generic-loop-rewriting.f90
 create mode 100644 flang/test/Transforms/generic-loop-rewriting.mlir

diff --git a/flang/include/flang/Common/OpenMP-utils.h b/flang/include/flang/Common/OpenMP-utils.h
new file mode 100644
index 00000000000000..7dbb0f612b19cd
--- /dev/null
+++ b/flang/include/flang/Common/OpenMP-utils.h
@@ -0,0 +1,68 @@
+//===-- include/flang/Common/OpenMP-utils.h ---------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_COMMON_OPENMP_UTILS_H_
+#define FORTRAN_COMMON_OPENMP_UTILS_H_
+
+#include "flang/Semantics/symbol.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace Fortran::openmp::common {
+/// Structure holding the information needed to create and bind entry block
+/// arguments associated to a single clause.
+struct EntryBlockArgsEntry {
+  llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
+  llvm::ArrayRef<mlir::Value> vars;
+
+  bool isValid() const {
+    // This check allows specifying a smaller number of symbols than values
+    // because in some cases a single symbol generates multiple block
+    // arguments.
+    return syms.size() <= vars.size();
+  }
+};
+
+/// Structure holding the information needed to create and bind entry block
+/// arguments associated to all clauses that can define them.
+struct EntryBlockArgs {
+  EntryBlockArgsEntry inReduction;
+  EntryBlockArgsEntry map;
+  EntryBlockArgsEntry priv;
+  EntryBlockArgsEntry reduction;
+  EntryBlockArgsEntry taskReduction;
+  EntryBlockArgsEntry useDeviceAddr;
+  EntryBlockArgsEntry useDevicePtr;
+
+  bool isValid() const {
+    return inReduction.isValid() && map.isValid() && priv.isValid() &&
+        reduction.isValid() && taskReduction.isValid() &&
+        useDeviceAddr.isValid() && useDevicePtr.isValid();
+  }
+
+  auto getSyms() const {
+    return llvm::concat<const Fortran::semantics::Symbol *const>(
+        inReduction.syms, map.syms, priv.syms, reduction.syms,
+        taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
+  }
+
+  auto getVars() const {
+    return llvm::concat<const mlir::Value>(inReduction.vars, map.vars,
+        priv.vars, reduction.vars, taskReduction.vars, useDeviceAddr.vars,
+        useDevicePtr.vars);
+  }
+};
+
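+/// Create an entry block for the given region, adding arguments for each of
+/// the clause-defined entry block argument lists in \p args.
+///
+/// \param [in] builder - MLIR op builder used to create the block.
+/// \param [in]    args - entry block arguments information for the operation
+///                       owning \p region.
+/// \param [in]  region - Empty region in which to create the entry block.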
+mlir::Block *genEntryBlock(
+    mlir::OpBuilder &builder, const EntryBlockArgs &args, mlir::Region &region);
+} // namespace Fortran::openmp::common
+
+#endif // FORTRAN_COMMON_OPENMP_UTILS_H_
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 37977334c1e9ed..c89b3485ae96de 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -50,9 +50,30 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
   ];
 }
 
+
 // Needs to be scheduled on Module as we create functions in it
 def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
   let summary = "Lower workshare construct";
 }
 
+def GenericLoopConversionPass
+    : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
+  let summary = "Converts OpenMP generic `loop` directive to semantically "
+                "equivalent OpenMP ops";
+  let description = [{
+     Rewrites `loop` ops to their semantically equivalent nest of ops. The
+     rewrite depends on the nesting/combination structure of the `loop` op
+     within its surrounding context as well as its `bind` clause value.
+
+     We assume for now that all `omp.loop` ops will occur inside `FuncOp`s.
+     This will most likely remain the case in the future; even if, for example,
+     we need a loop to copy data for a `firstprivate` variable, that loop will
+     be nested in a constructor, an overloaded operator, or a runtime function.
+  }];
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::omp::OpenMPDialect"
+  ];
+}
+
 #endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/lib/Common/CMakeLists.txt b/flang/lib/Common/CMakeLists.txt
index be72391847f3dd..de6bea396f3cbe 100644
--- a/flang/lib/Common/CMakeLists.txt
+++ b/flang/lib/Common/CMakeLists.txt
@@ -40,9 +40,13 @@ add_flang_library(FortranCommon
   default-kinds.cpp
   idioms.cpp
   LangOptions.cpp
+  OpenMP-utils.cpp
   Version.cpp
   ${version_inc}
 
   LINK_COMPONENTS
   Support
+
+  LINK_LIBS
+  MLIRIR
 )
diff --git a/flang/lib/Common/OpenMP-utils.cpp b/flang/lib/Common/OpenMP-utils.cpp
new file mode 100644
index 00000000000000..32df2be01e5484
--- /dev/null
+++ b/flang/lib/Common/OpenMP-utils.cpp
@@ -0,0 +1,47 @@
+//===-- lib/Common/OpenMP-utils.cpp -----------------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/OpenMP-utils.h"
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace Fortran::openmp::common {
+mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args,
+    mlir::Region &region) {
+  assert(args.isValid() && "invalid args");
+  assert(region.empty() && "non-empty region");
+
+  llvm::SmallVector<mlir::Type> types;
+  llvm::SmallVector<mlir::Location> locs;
+  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
+      args.priv.vars.size() + args.reduction.vars.size() +
+      args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() +
+      args.useDevicePtr.vars.size();
+  types.reserve(numVars);
+  locs.reserve(numVars);
+
+  auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) {
+    llvm::transform(vals, std::back_inserter(types),
+        [](mlir::Value v) { return v.getType(); });
+    llvm::transform(vals, std::back_inserter(locs),
+        [](mlir::Value v) { return v.getLoc(); });
+  };
+
+  // Populate block arguments in clause name alphabetical order to match the
+  // order expected by the BlockArgOpenMPOpInterface.
+  extractTypeLoc(args.inReduction.vars);
+  extractTypeLoc(args.map.vars);
+  extractTypeLoc(args.priv.vars);
+  extractTypeLoc(args.reduction.vars);
+  extractTypeLoc(args.taskReduction.vars);
+  extractTypeLoc(args.useDeviceAddr.vars);
+  extractTypeLoc(args.useDevicePtr.vars);
+
+  return builder.createBlock(&region, {}, types, locs);
+}
+} // namespace Fortran::openmp::common
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index a2779213a1a15a..48b499cae2681c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -19,6 +19,7 @@
 #include "DirectivesCommon.h"
 #include "ReductionProcessor.h"
 #include "Utils.h"
+#include "flang/Common/OpenMP-utils.h"
 #include "flang/Common/idioms.h"
 #include "flang/Lower/Bridge.h"
 #include "flang/Lower/ConvertExpr.h"
@@ -41,57 +42,12 @@
 #include "llvm/Frontend/OpenMP/OMPConstants.h"
 
 using namespace Fortran::lower::omp;
+using namespace Fortran::openmp::common;
 
 //===----------------------------------------------------------------------===//
 // Code generation helper functions
 //===----------------------------------------------------------------------===//
 
-namespace {
-/// Structure holding the information needed to create and bind entry block
-/// arguments associated to a single clause.
-struct EntryBlockArgsEntry {
-  llvm::ArrayRef<const semantics::Symbol *> syms;
-  llvm::ArrayRef<mlir::Value> vars;
-
-  bool isValid() const {
-    // This check allows specifying a smaller number of symbols than values
-    // because in some case cases a single symbol generates multiple block
-    // arguments.
-    return syms.size() <= vars.size();
-  }
-};
-
-/// Structure holding the information needed to create and bind entry block
-/// arguments associated to all clauses that can define them.
-struct EntryBlockArgs {
-  EntryBlockArgsEntry inReduction;
-  EntryBlockArgsEntry map;
-  EntryBlockArgsEntry priv;
-  EntryBlockArgsEntry reduction;
-  EntryBlockArgsEntry taskReduction;
-  EntryBlockArgsEntry useDeviceAddr;
-  EntryBlockArgsEntry useDevicePtr;
-
-  bool isValid() const {
-    return inReduction.isValid() && map.isValid() && priv.isValid() &&
-           reduction.isValid() && taskReduction.isValid() &&
-           useDeviceAddr.isValid() && useDevicePtr.isValid();
-  }
-
-  auto getSyms() const {
-    return llvm::concat<const semantics::Symbol *const>(
-        inReduction.syms, map.syms, priv.syms, reduction.syms,
-        taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
-  }
-
-  auto getVars() const {
-    return llvm::concat<const mlir::Value>(
-        inReduction.vars, map.vars, priv.vars, reduction.vars,
-        taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
-  }
-};
-} // namespace
-
 static void genOMPDispatch(lower::AbstractConverter &converter,
                            lower::SymMap &symTable,
                            semantics::SemanticsContext &semaCtx,
@@ -623,50 +579,6 @@ static void genLoopVars(
   firOpBuilder.setInsertionPointAfter(storeOp);
 }
 
-/// Create an entry block for the given region, including the clause-defined
-/// arguments specified.
-///
-/// \param [in] converter - PFT to MLIR conversion interface.
-/// \param [in]      args - entry block arguments information for the given
-///                         operation.
-/// \param [in]    region - Empty region in which to create the entry block.
-static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
-                                  const EntryBlockArgs &args,
-                                  mlir::Region &region) {
-  assert(args.isValid() && "invalid args");
-  assert(region.empty() && "non-empty region");
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  llvm::SmallVector<mlir::Type> types;
-  llvm::SmallVector<mlir::Location> locs;
-  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
-                     args.priv.vars.size() + args.reduction.vars.size() +
-                     args.taskReduction.vars.size() +
-                     args.useDeviceAddr.vars.size() +
-                     args.useDevicePtr.vars.size();
-  types.reserve(numVars);
-  locs.reserve(numVars);
-
-  auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) {
-    llvm::transform(vals, std::back_inserter(types),
-                    [](mlir::Value v) { return v.getType(); });
-    llvm::transform(vals, std::back_inserter(locs),
-                    [](mlir::Value v) { return v.getLoc(); });
-  };
-
-  // Populate block arguments in clause name alphabetical order to match
-  // expected order by the BlockArgOpenMPOpInterface.
-  extractTypeLoc(args.inReduction.vars);
-  extractTypeLoc(args.map.vars);
-  extractTypeLoc(args.priv.vars);
-  extractTypeLoc(args.reduction.vars);
-  extractTypeLoc(args.taskReduction.vars);
-  extractTypeLoc(args.useDeviceAddr.vars);
-  extractTypeLoc(args.useDevicePtr.vars);
-
-  return firOpBuilder.createBlock(&region, {}, types, locs);
-}
-
 static void
 markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
                   mlir::omp::DeclareTargetCaptureClause captureClause,
@@ -919,7 +831,7 @@ static void genBodyOfTargetDataOp(
     ConstructQueue::const_iterator item) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
 
-  genEntryBlock(converter, args, dataOp.getRegion());
+  genEntryBlock(converter.getFirOpBuilder(), args, dataOp.getRegion());
   bindEntryBlockArgs(converter, dataOp, args);
 
   // Insert dummy instruction to remember the insertion position. The
@@ -996,7 +908,8 @@ static void genBodyOfTargetOp(
   auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
 
   mlir::Region &region = targetOp.getRegion();
-  mlir::Block *entryBlock = genEntryBlock(converter, args, region);
+  mlir::Block *entryBlock =
+      genEntryBlock(converter.getFirOpBuilder(), args, region);
   bindEntryBlockArgs(converter, targetOp, args);
 
   // Check if cloning the bounds introduced any dependency on the outer region.
@@ -1122,7 +1035,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter,
   auto op = firOpBuilder.create<OpTy>(loc, clauseOps);
 
   // Create entry block with arguments.
-  genEntryBlock(converter, args, op.getRegion());
+  genEntryBlock(converter.getFirOpBuilder(), args, op.getRegion());
 
   return op;
 }
@@ -1588,7 +1501,7 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
               const EntryBlockArgs &args, DataSharingProcessor *dsp,
               bool isComposite = false) {
   auto genRegionEntryCB = [&](mlir::Operation *op) {
-    genEntryBlock(converter, args, op->getRegion(0));
+    genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
     bindEntryBlockArgs(
         converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
     return llvm::to_vector(args.getSyms());
@@ -1661,12 +1574,12 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   args.reduction.syms = reductionSyms;
   args.reduction.vars = clauseOps.reductionVars;
 
-  genEntryBlock(converter, args, sectionsOp.getRegion());
+  genEntryBlock(converter.getFirOpBuilder(), args, sectionsOp.getRegion());
   mlir::Operation *terminator =
       lower::genOpenMPTerminator(builder, sectionsOp, loc);
 
   auto genRegionEntryCB = [&](mlir::Operation *op) {
-    genEntryBlock(converter, args, op->getRegion(0));
+    genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
     bindEntryBlockArgs(
         converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
     return llvm::to_vector(args.getSyms());
@@ -1989,7 +1902,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
   taskArgs.priv.vars = clauseOps.privateVars;
 
   auto genRegionEntryCB = [&](mlir::Operation *op) {
-    genEntryBlock(converter, taskArgs, op->getRegion(0));
+    genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
     bindEntryBlockArgs(converter,
                        llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
                        taskArgs);
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b1e0dbf6e707e5..51ecbe1a664f92 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 
 add_flang_library(FlangOpenMPTransforms
   FunctionFiltering.cpp
+  GenericLoopConversion.cpp
   MapsForPrivatizedSymbols.cpp
   MapInfoFinalization.cpp
   MarkDeclareTarget.cpp
@@ -25,4 +26,5 @@ add_flang_library(FlangOpenMPTransforms
   HLFIRDialect
   MLIRIR
   MLIRPass
+  MLIRTransformUtils
 )
diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
new file mode 100644
index 00000000000000..cf001a203703e1
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -0,0 +1,168 @@
+//===- GenericLoopConversion.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/OpenMP-utils.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "flang/Semantics/symbol.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <memory>
+
+namespace flangomp {
+#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+namespace {
+
+/// A conversion pattern to handle various combined forms of `omp.loop`. For how
+/// combined/composite directives are handled, see:
+/// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986.
+class GenericLoopConversionPattern
+    : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
+public:
+  enum class GenericLoopCombinedInfo {
+    None,
+    TargetTeamsLoop,
+    TargetParallelLoop
+  };
+
+  using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    assert(isLoopConversionSupported(loopOp));
+
+    rewriteToDistributeParallelDo(loopOp, rewriter);
+    rewriter.eraseOp(loopOp);
+    return mlir::success();
+  }
+
+  static bool isLoopConversionSupported(mlir::omp::LoopOp loopOp) {
+    GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
+
+    // TODO Support standalone `loop` ops and other forms of combined `loop` op
+    // nests.
+    if (combinedInfo != GenericLoopCombinedInfo::TargetTeamsLoop)
+      return false;
+
+    // TODO Support other clauses.
+    if (loopOp.getBindKind() || loopOp.getOrder() ||
+        !loopOp.getReductionVars().empty())
+      return false;
+
+    // TODO For `target teams loop`, check constraints similar to what is checked
+    // by `TeamsLoopChecker` in SemaOpenMP.cpp.
+    return true;
+  }
+
+private:
+  static GenericLoopCombinedInfo
+  findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
+    mlir::Operation *parentOp = loopOp->getParentOp();
+    GenericLoopCombinedInfo result = GenericLoopCombinedInfo::None;
+
+    if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
+      if (mlir::isa<mlir::omp::TargetOp>(teamsOp->getParentOp()))
+        result = GenericLoopCombinedInfo::TargetTeamsLoop;
+
+    if (auto parallelOp =
+            mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
+      if (mlir::isa<mlir::omp::TargetOp>(parallelOp->getParentOp()))
+        result = GenericLoopCombinedInfo::TargetParallelLoop;
+
+    return result;
+  }
+
+  void rewriteToDistributeParallelDo(
+      mlir::omp::LoopOp loopOp,
+      mlir::ConversionPatternRewriter &rewriter) const {
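+    // Rough sketch of the rewrite: wrap the original `loop_nest` in
+    // `parallel { distribute { wsloop { ... } } }`. The `loop` op's private
+    // clauses are moved to the new `parallel` op, each wrapper op is marked
+    // composite, and the `loop_nest` is cloned with the `loop` region's block
+    // arguments remapped to the `parallel` entry block arguments.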
+    mlir::omp::ParallelOperands parallelClauseOps;
+    parallelClauseOps.privateVars = loopOp.getPrivateVars();
+
+    auto privateSyms = loopOp.getPrivateSyms();
+    if (privateSyms)
+      parallelClauseOps.privateSyms.assign(privateSyms->begin(),
+                                           privateSyms->end());
+
+    Fortran::openmp::common::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelClauseOps.privateVars;
+
+    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
+                                                             parallelClauseOps);
+    mlir::Block *parallelBlock =
+        genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
+    parallelOp.setComposite(true);
+    rewriter.setInsertionPoint(
+        rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
+
+    mlir::omp::DistributeOperands distributeClauseOps;
+    auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
+        loopOp.getLoc(), distributeClauseOps);
+    distributeOp.setComposite(true);
+    rewriter.createBlock(&distributeOp.getRegion());
+
+    mlir::omp::WsloopOperands wsloopClauseOps;
+    auto wsloopOp =
+        rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
+    wsloopOp.setComposite(true);
+    rewriter.createBlock(&wsloopOp.getRegion());
+
+    mlir::IRMapping mapper;
+    mlir::Block &loopBlock = *loopOp.getRegion().begin();
+
+    for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
+             loopBlock.getArguments(), parallelBlock->getArguments()))
+      mapper.map(loopOpArg, parallelOpArg);
+
+    rewriter.clone(*loopOp.begin(), mapper);
+  }
+};
+
+class GenericLoopConversionPass
+    : public flangomp::impl::GenericLoopConversionPassBase<
+          GenericLoopConversionPass> {
+public:
+  GenericLoopConversionPass() = default;
+
+  void runOnOperation() override {
+    mlir::func::FuncOp func = getOperation();
+
+    if (func.isDeclaration()) {
+      return;
+    }
+
+    mlir::MLIRContext *context = &getContext();
+    mlir::RewritePatternSet patterns(context);
+    patterns.insert<GenericLoopConversionPattern>(context);
+    mlir::ConversionTarget target(*context);
+
+    target.markUnknownOpDynamicallyLegal(
+        [](mlir::Operation *) { return true; });
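+    // Only `omp.loop` ops in a currently supported form are considered
+    // illegal and hence rewritten by the pattern above; unsupported forms are
+    // left untouched for now.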
+    target.addDynamicallyLegalOp<mlir::omp::LoopOp>(
+        [](mlir::omp::LoopOp loopOp) {
+          return !GenericLoopConversionPattern::isLoopConversionSupported(
+              loopOp);
+        });
+
+    if (mlir::failed(mlir::applyFullConversion(getOperation(), target,
+                                               std::move(patterns)))) {
+      mlir::emitError(func.getLoc(), "error in converting `omp.loop` op");
+      signalPassFailure();
+    }
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 31af3531641dda..1568165bcd64a4 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -247,6 +247,7 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm, bool isTargetDevice) {
   pm.addPass(flangomp::createMapInfoFinalizationPass());
   pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
   pm.addPass(flangomp::createMarkDeclareTargetPass());
+  pm.addPass(flangomp::createGenericLoopConversionPass());
   if (isTargetDevice)
     pm.addPass(flangomp::createFunctionFilteringPass());
 }
diff --git a/flang/test/Lower/OpenMP/generic-loop-rewriting.f90 b/flang/test/Lower/OpenMP/generic-loop-rewriting.f90
new file mode 100644
index 00000000000000..fa26425356dd90
--- /dev/null
+++ b/flang/test/Lower/OpenMP/generic-loop-rewriting.f90
@@ -0,0 +1,41 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine target_teams_loop
+    implicit none
+    integer :: x, i
+
+    !$omp target teams loop
+    do i = 0, 10
+      x = x + i
+    end do
+end subroutine target_teams_loop
+
+!CHECK-LABEL: func.func @_QPtarget_teams_loop
+!CHECK:         omp.target map_entries(
+!CHECK-SAME:      %{{.*}} -> %[[I_ARG:[^[:space:]]+]],
+!CHECK-SAME:      %{{.*}} -> %[[X_ARG:[^[:space:]]+]] : {{.*}}) {
+
+!CHECK:           %[[I_DECL:.*]]:2 = hlfir.declare %[[I_ARG]]
+!CHECK:           %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]]
+
+!CHECK:           omp.teams {
+
+!CHECK:             %[[LB:.*]] = arith.constant 0 : i32
+!CHECK:             %[[UB:.*]] = arith.constant 10 : i32
+!CHECK:             %[[STEP:.*]] = arith.constant 1 : i32
+
+!CHECK:             omp.parallel private(@{{.*}} %[[I_DECL]]#0 
+!CHECK-SAME:          -> %[[I_PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
+!CHECK:               omp.distribute {
+!CHECK:                 omp.wsloop {
+
+!CHECK:                   omp.loop_nest (%{{.*}}) : i32 = 
+!CHECK-SAME:                (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+!CHECK:                     %[[I_PRIV_DECL:.*]]:2 = hlfir.declare %[[I_PRIV_ARG]]
+!CHECK:                     fir.store %{{.*}} to %[[I_PRIV_DECL]]#1 : !fir.ref<i32>
+!CHECK:                   }
+!CHECK:                 }
+!CHECK:               }
+!CHECK:             }
+!CHECK:           }
+!CHECK:         }
diff --git a/flang/test/Transforms/generic-loop-rewriting.mlir b/flang/test/Transforms/generic-loop-rewriting.mlir
new file mode 100644
index 00000000000000..a18ea9853602ac
--- /dev/null
+++ b/flang/test/Transforms/generic-loop-rewriting.mlir
@@ -0,0 +1,57 @@
+// RUN: fir-opt --omp-generic-loop-conversion %s | FileCheck %s
+
+omp.private {type = private} @_QFtarget_teams_loopEi_private_ref_i32 : !fir.ref<i32> alloc {
+^bb0(%arg0: !fir.ref<i32>):
+  omp.yield(%arg0 : !fir.ref<i32>)
+}
+
+func.func @_QPtarget_teams_loop() {
+  %i = fir.alloca i32
+  %i_map = omp.map.info var_ptr(%i : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+  omp.target map_entries(%i_map -> %arg0 : !fir.ref<i32>) {
+    omp.teams {
+      %c0 = arith.constant 0 : i32
+      %c10 = arith.constant 10 : i32
+      %c1 = arith.constant 1 : i32
+      omp.loop private(@_QFtarget_teams_loopEi_private_ref_i32 %arg0 -> %arg2 : !fir.ref<i32>) {
+        omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
+          fir.store %arg3 to %arg2 : !fir.ref<i32>
+          omp.yield
+        }
+      }
+      omp.terminator
+    }
+    omp.terminator
+  }
+  return
+}
+
+// CHECK-LABEL: func.func @_QPtarget_teams_loop
+// CHECK:         omp.target map_entries(
+// CHECK-SAME:      %{{.*}} -> %[[I_ARG:[^[:space:]]+]] : {{.*}}) {
+// 
+// CHECK:           omp.teams {
+// 
+// TODO we probably need to move the `loop_nest` bounds ops from the `teams`
+// region to the `parallel` region to avoid making these values `shared`. We can
+// find the backward slices of these bounds that are within the `teams` region
+// and move these slices to the `parallel` op.
+
+// CHECK:             %[[LB:.*]] = arith.constant 0 : i32
+// CHECK:             %[[UB:.*]] = arith.constant 10 : i32
+// CHECK:             %[[STEP:.*]] = arith.constant 1 : i32
+// 
+// CHECK:             omp.parallel private(@{{.*}} %[[I_ARG]]
+// CHECK-SAME:          -> %[[I_PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
+// CHECK:               omp.distribute {
+// CHECK:                 omp.wsloop {
+// 
+// CHECK:                   omp.loop_nest (%{{.*}}) : i32 = 
+// CHECK-SAME:                (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+// CHECK:                     fir.store %{{.*}} to %[[I_PRIV_ARG]] : !fir.ref<i32>
+// CHECK:                   }
+// CHECK:                 }
+// CHECK:               }
+// CHECK:             }
+// CHECK:           }
+// CHECK:         }



More information about the flang-commits mailing list