[flang-commits] [flang] [flang][OpenMP] Rewrite `omp.loop` to semantically equivalent ops (PR #115443)
Kareem Ergawy via flang-commits
flang-commits at lists.llvm.org
Thu Nov 21 20:36:50 PST 2024
https://github.com/ergawy updated https://github.com/llvm/llvm-project/pull/115443
>From 032afb5843f312a39c4a5ccf95fbc2e982116732 Mon Sep 17 00:00:00 2001
From: ergawy <kareem.ergawy at amd.com>
Date: Thu, 7 Nov 2024 04:26:57 -0600
Subject: [PATCH] [flang][OpenMP] Rewrite `omp.loop` to semantically equivalent
ops
Introduces a new conversion pass that rewrites `omp.loop` ops to their
semantically equivalent op nests based on the surrounding/binding
context of the `loop` op. Not all forms of `omp.loop` are supported yet;
currently only `target teams loop` is rewritten, to a composite
`distribute parallel do` nest.
---
flang/include/flang/Common/OpenMP-utils.h | 68 +++++++
.../include/flang/Optimizer/OpenMP/Passes.td | 16 ++
flang/lib/Common/CMakeLists.txt | 4 +
flang/lib/Common/OpenMP-utils.cpp | 47 +++++
flang/lib/Lower/OpenMP/OpenMP.cpp | 107 ++---------
flang/lib/Optimizer/OpenMP/CMakeLists.txt | 2 +
.../OpenMP/GenericLoopConversion.cpp | 168 ++++++++++++++++++
flang/lib/Optimizer/Passes/Pipelines.cpp | 1 +
.../Lower/OpenMP/generic-loop-rewriting.f90 | 41 +++++
.../Transforms/generic-loop-rewriting.mlir | 57 ++++++
10 files changed, 414 insertions(+), 97 deletions(-)
create mode 100644 flang/include/flang/Common/OpenMP-utils.h
create mode 100644 flang/lib/Common/OpenMP-utils.cpp
create mode 100644 flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
create mode 100644 flang/test/Lower/OpenMP/generic-loop-rewriting.f90
create mode 100644 flang/test/Transforms/generic-loop-rewriting.mlir
diff --git a/flang/include/flang/Common/OpenMP-utils.h b/flang/include/flang/Common/OpenMP-utils.h
new file mode 100644
index 00000000000000..7dbb0f612b19cd
--- /dev/null
+++ b/flang/include/flang/Common/OpenMP-utils.h
@@ -0,0 +1,68 @@
+//===-- include/flang/Common/OpenMP-utils.h --------------------*- C++ -*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_COMMON_OPENMP_UTILS_H_
+#define FORTRAN_COMMON_OPENMP_UTILS_H_
+
+#include "flang/Semantics/symbol.h"
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/ADT/ArrayRef.h"
+
+namespace Fortran::openmp::common {
+/// Structure holding the information needed to create and bind entry block
+/// arguments associated to a single clause.
+struct EntryBlockArgsEntry {
+  llvm::ArrayRef<const Fortran::semantics::Symbol *> syms;
+  llvm::ArrayRef<mlir::Value> vars;
+
+  bool isValid() const {
+    // Fewer symbols than values are allowed because, in some cases, a single
+    // symbol generates multiple block arguments.
+    return syms.size() <= vars.size();
+  }
+};
+
+/// Structure holding the information needed to create and bind entry block
+/// arguments associated to all clauses that can define them.
+struct EntryBlockArgs {
+  EntryBlockArgsEntry inReduction;
+  EntryBlockArgsEntry map;
+  EntryBlockArgsEntry priv;
+  EntryBlockArgsEntry reduction;
+  EntryBlockArgsEntry taskReduction;
+  EntryBlockArgsEntry useDeviceAddr;
+  EntryBlockArgsEntry useDevicePtr;
+
+  bool isValid() const {
+    return inReduction.isValid() && map.isValid() && priv.isValid() &&
+        reduction.isValid() && taskReduction.isValid() &&
+        useDeviceAddr.isValid() && useDevicePtr.isValid();
+  }
+
+  auto getSyms() const {
+    return llvm::concat<const Fortran::semantics::Symbol *const>(
+        inReduction.syms, map.syms, priv.syms, reduction.syms,
+        taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
+  }
+
+  auto getVars() const {
+    return llvm::concat<const mlir::Value>(inReduction.vars, map.vars,
+        priv.vars, reduction.vars, taskReduction.vars, useDeviceAddr.vars,
+        useDevicePtr.vars);
+  }
+};
+
+/// Create the entry block of the empty \p region with arguments for \p args.
+mlir::Block *genEntryBlock(
+    mlir::OpBuilder &builder, const EntryBlockArgs &args, mlir::Region &region);
+} // namespace Fortran::openmp::common
+
+#endif // FORTRAN_COMMON_OPENMP_UTILS_H_
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 37977334c1e9ed..05a910c056876d 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -50,9 +50,25 @@ def FunctionFilteringPass : Pass<"omp-function-filtering"> {
];
}
+
// Needs to be scheduled on Module as we create functions in it
def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
let summary = "Lower workshare construct";
}
+def GenericLoopConversionPass
+    : Pass<"omp-generic-loop-conversion", "mlir::func::FuncOp"> {
+  let summary = "Converts OpenMP generic `loop` directive to semantically "
+                "equivalent OpenMP ops";
+  let description = [{
+    Rewrites `loop` ops to semantically equivalent op nests chosen from their
+    nesting/combination context. So far only `target teams loop` without
+    `bind`, `order` or `reduction` clauses is rewritten; other forms are kept.
+  }];
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::omp::OpenMPDialect"
+  ];
+}
+
#endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/lib/Common/CMakeLists.txt b/flang/lib/Common/CMakeLists.txt
index be72391847f3dd..de6bea396f3cbe 100644
--- a/flang/lib/Common/CMakeLists.txt
+++ b/flang/lib/Common/CMakeLists.txt
@@ -40,9 +40,13 @@ add_flang_library(FortranCommon
default-kinds.cpp
idioms.cpp
LangOptions.cpp
+ OpenMP-utils.cpp
Version.cpp
${version_inc}
LINK_COMPONENTS
Support
+
+ LINK_LIBS
+ MLIRIR
)
diff --git a/flang/lib/Common/OpenMP-utils.cpp b/flang/lib/Common/OpenMP-utils.cpp
new file mode 100644
index 00000000000000..32df2be01e5484
--- /dev/null
+++ b/flang/lib/Common/OpenMP-utils.cpp
@@ -0,0 +1,47 @@
+//===-- include/flang/Common/OpenMP-utils.cpp ------------------*- C++ -*-====//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/OpenMP-utils.h"
+
+#include "mlir/IR/OpDefinition.h"
+
+namespace Fortran::openmp::common {
+/// Create an entry block for the given region, including the clause-defined
+/// arguments specified. Argument types and locations are taken from the values
+/// in \p args, in the order returned by EntryBlockArgs::getVars().
+///
+/// \param [in] builder - builder used to create the block.
+/// \param [in] args - entry block arguments information for the given op.
+/// \param [in] region - Empty region in which to create the entry block.
+mlir::Block *genEntryBlock(mlir::OpBuilder &builder, const EntryBlockArgs &args,
+    mlir::Region &region) {
+  assert(args.isValid() && "invalid args");
+  assert(region.empty() && "non-empty region");
+
+  llvm::SmallVector<mlir::Type> types;
+  llvm::SmallVector<mlir::Location> locs;
+  unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
+      args.priv.vars.size() + args.reduction.vars.size() +
+      args.taskReduction.vars.size() + args.useDeviceAddr.vars.size() +
+      args.useDevicePtr.vars.size();
+  types.reserve(numVars);
+  locs.reserve(numVars);
+
+  // getVars() concatenates the per-clause values in clause-name alphabetical
+  // order (inReduction, map, priv, reduction, taskReduction, useDeviceAddr,
+  // useDevicePtr), which is the block argument order expected by the
+  // BlockArgOpenMPOpInterface. Iterating it, instead of listing every clause
+  // again here, keeps that order defined in a single place.
+  for (mlir::Value var : args.getVars()) {
+    types.push_back(var.getType());
+    locs.push_back(var.getLoc());
+  }
+
+  return builder.createBlock(&region, {}, types, locs);
+}
+} // namespace Fortran::openmp::common
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index a2779213a1a15a..48b499cae2681c 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -19,6 +19,7 @@
#include "DirectivesCommon.h"
#include "ReductionProcessor.h"
#include "Utils.h"
+#include "flang/Common/OpenMP-utils.h"
#include "flang/Common/idioms.h"
#include "flang/Lower/Bridge.h"
#include "flang/Lower/ConvertExpr.h"
@@ -41,57 +42,12 @@
#include "llvm/Frontend/OpenMP/OMPConstants.h"
using namespace Fortran::lower::omp;
+using namespace Fortran::openmp::common;
//===----------------------------------------------------------------------===//
// Code generation helper functions
//===----------------------------------------------------------------------===//
-namespace {
-/// Structure holding the information needed to create and bind entry block
-/// arguments associated to a single clause.
-struct EntryBlockArgsEntry {
- llvm::ArrayRef<const semantics::Symbol *> syms;
- llvm::ArrayRef<mlir::Value> vars;
-
- bool isValid() const {
- // This check allows specifying a smaller number of symbols than values
- // because in some case cases a single symbol generates multiple block
- // arguments.
- return syms.size() <= vars.size();
- }
-};
-
-/// Structure holding the information needed to create and bind entry block
-/// arguments associated to all clauses that can define them.
-struct EntryBlockArgs {
- EntryBlockArgsEntry inReduction;
- EntryBlockArgsEntry map;
- EntryBlockArgsEntry priv;
- EntryBlockArgsEntry reduction;
- EntryBlockArgsEntry taskReduction;
- EntryBlockArgsEntry useDeviceAddr;
- EntryBlockArgsEntry useDevicePtr;
-
- bool isValid() const {
- return inReduction.isValid() && map.isValid() && priv.isValid() &&
- reduction.isValid() && taskReduction.isValid() &&
- useDeviceAddr.isValid() && useDevicePtr.isValid();
- }
-
- auto getSyms() const {
- return llvm::concat<const semantics::Symbol *const>(
- inReduction.syms, map.syms, priv.syms, reduction.syms,
- taskReduction.syms, useDeviceAddr.syms, useDevicePtr.syms);
- }
-
- auto getVars() const {
- return llvm::concat<const mlir::Value>(
- inReduction.vars, map.vars, priv.vars, reduction.vars,
- taskReduction.vars, useDeviceAddr.vars, useDevicePtr.vars);
- }
-};
-} // namespace
-
static void genOMPDispatch(lower::AbstractConverter &converter,
lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx,
@@ -623,50 +579,6 @@ static void genLoopVars(
firOpBuilder.setInsertionPointAfter(storeOp);
}
-/// Create an entry block for the given region, including the clause-defined
-/// arguments specified.
-///
-/// \param [in] converter - PFT to MLIR conversion interface.
-/// \param [in] args - entry block arguments information for the given
-/// operation.
-/// \param [in] region - Empty region in which to create the entry block.
-static mlir::Block *genEntryBlock(lower::AbstractConverter &converter,
- const EntryBlockArgs &args,
- mlir::Region &region) {
- assert(args.isValid() && "invalid args");
- assert(region.empty() && "non-empty region");
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- llvm::SmallVector<mlir::Type> types;
- llvm::SmallVector<mlir::Location> locs;
- unsigned numVars = args.inReduction.vars.size() + args.map.vars.size() +
- args.priv.vars.size() + args.reduction.vars.size() +
- args.taskReduction.vars.size() +
- args.useDeviceAddr.vars.size() +
- args.useDevicePtr.vars.size();
- types.reserve(numVars);
- locs.reserve(numVars);
-
- auto extractTypeLoc = [&types, &locs](llvm::ArrayRef<mlir::Value> vals) {
- llvm::transform(vals, std::back_inserter(types),
- [](mlir::Value v) { return v.getType(); });
- llvm::transform(vals, std::back_inserter(locs),
- [](mlir::Value v) { return v.getLoc(); });
- };
-
- // Populate block arguments in clause name alphabetical order to match
- // expected order by the BlockArgOpenMPOpInterface.
- extractTypeLoc(args.inReduction.vars);
- extractTypeLoc(args.map.vars);
- extractTypeLoc(args.priv.vars);
- extractTypeLoc(args.reduction.vars);
- extractTypeLoc(args.taskReduction.vars);
- extractTypeLoc(args.useDeviceAddr.vars);
- extractTypeLoc(args.useDevicePtr.vars);
-
- return firOpBuilder.createBlock(&region, {}, types, locs);
-}
-
static void
markDeclareTarget(mlir::Operation *op, lower::AbstractConverter &converter,
mlir::omp::DeclareTargetCaptureClause captureClause,
@@ -919,7 +831,7 @@ static void genBodyOfTargetDataOp(
ConstructQueue::const_iterator item) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
- genEntryBlock(converter, args, dataOp.getRegion());
+  genEntryBlock(firOpBuilder, args, dataOp.getRegion());
bindEntryBlockArgs(converter, dataOp, args);
// Insert dummy instruction to remember the insertion position. The
@@ -996,7 +908,8 @@ static void genBodyOfTargetOp(
auto argIface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(*targetOp);
mlir::Region &region = targetOp.getRegion();
- mlir::Block *entryBlock = genEntryBlock(converter, args, region);
+ mlir::Block *entryBlock =
+ genEntryBlock(converter.getFirOpBuilder(), args, region);
bindEntryBlockArgs(converter, targetOp, args);
// Check if cloning the bounds introduced any dependency on the outer region.
@@ -1122,7 +1035,7 @@ static OpTy genWrapperOp(lower::AbstractConverter &converter,
auto op = firOpBuilder.create<OpTy>(loc, clauseOps);
// Create entry block with arguments.
- genEntryBlock(converter, args, op.getRegion());
+  genEntryBlock(firOpBuilder, args, op.getRegion());
return op;
}
@@ -1588,7 +1501,7 @@ genParallelOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
const EntryBlockArgs &args, DataSharingProcessor *dsp,
bool isComposite = false) {
auto genRegionEntryCB = [&](mlir::Operation *op) {
- genEntryBlock(converter, args, op->getRegion(0));
+ genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
bindEntryBlockArgs(
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
return llvm::to_vector(args.getSyms());
@@ -1661,12 +1574,12 @@ genSectionsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
args.reduction.syms = reductionSyms;
args.reduction.vars = clauseOps.reductionVars;
- genEntryBlock(converter, args, sectionsOp.getRegion());
+ genEntryBlock(converter.getFirOpBuilder(), args, sectionsOp.getRegion());
mlir::Operation *terminator =
lower::genOpenMPTerminator(builder, sectionsOp, loc);
auto genRegionEntryCB = [&](mlir::Operation *op) {
- genEntryBlock(converter, args, op->getRegion(0));
+ genEntryBlock(converter.getFirOpBuilder(), args, op->getRegion(0));
bindEntryBlockArgs(
converter, llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op), args);
return llvm::to_vector(args.getSyms());
@@ -1989,7 +1902,7 @@ genTaskOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
taskArgs.priv.vars = clauseOps.privateVars;
auto genRegionEntryCB = [&](mlir::Operation *op) {
- genEntryBlock(converter, taskArgs, op->getRegion(0));
+ genEntryBlock(converter.getFirOpBuilder(), taskArgs, op->getRegion(0));
bindEntryBlockArgs(converter,
llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op),
taskArgs);
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index b1e0dbf6e707e5..51ecbe1a664f92 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -2,6 +2,7 @@ get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
add_flang_library(FlangOpenMPTransforms
FunctionFiltering.cpp
+ GenericLoopConversion.cpp
MapsForPrivatizedSymbols.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
@@ -25,4 +26,5 @@ add_flang_library(FlangOpenMPTransforms
HLFIRDialect
MLIRIR
MLIRPass
+ MLIRTransformUtils
)
diff --git a/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
new file mode 100644
index 00000000000000..cf001a203703e1
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/GenericLoopConversion.cpp
@@ -0,0 +1,168 @@
+//===- GenericLoopConversion.cpp ------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "flang/Common/OpenMP-utils.h"
+#include "flang/Optimizer/Dialect/FIRDialect.h"
+#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "flang/Semantics/symbol.h"
+
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+#include <memory>
+
+namespace flangomp {
+#define GEN_PASS_DEF_GENERICLOOPCONVERSIONPASS
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+namespace {
+
+/// A conversion pattern to handle various combined forms of `omp.loop`. For how
+/// combined/composite directives are handled see:
+/// https://discourse.llvm.org/t/rfc-representing-combined-composite-constructs-in-the-openmp-dialect/76986.
+class GenericLoopConversionPattern
+    : public mlir::OpConversionPattern<mlir::omp::LoopOp> {
+public:
+  enum class GenericLoopCombinedInfo {
+    None,
+    TargetTeamsLoop,
+    TargetParallelLoop
+  };
+
+  using mlir::OpConversionPattern<mlir::omp::LoopOp>::OpConversionPattern;
+
+  mlir::LogicalResult
+  matchAndRewrite(mlir::omp::LoopOp loopOp, OpAdaptor adaptor,
+                  mlir::ConversionPatternRewriter &rewriter) const override {
+    // Fail instead of asserting so unsupported forms are never rewritten.
+    if (!isLoopConversionSupported(loopOp))
+      return rewriter.notifyMatchFailure(loopOp, "unsupported `omp.loop` form");
+
+    rewriteToDistributeParallelDo(loopOp, rewriter);
+    rewriter.eraseOp(loopOp);
+    return mlir::success();
+  }
+
+  /// Returns true if \p loopOp has a form that this pattern can rewrite.
+  static bool isLoopConversionSupported(mlir::omp::LoopOp loopOp) {
+    GenericLoopCombinedInfo combinedInfo = findGenericLoopCombineInfo(loopOp);
+
+    // TODO Support standalone `loop` ops and other combined `loop` forms.
+    if (combinedInfo != GenericLoopCombinedInfo::TargetTeamsLoop)
+      return false;
+
+    // TODO Support other clauses.
+    if (loopOp.getBindKind() || loopOp.getOrder() ||
+        !loopOp.getReductionVars().empty())
+      return false;
+
+    // TODO For `target teams loop`, check constraints similar to those checked
+    // by `TeamsLoopChecker` in SemaOpenMP.cpp.
+    return true;
+  }
+
+private:
+  /// Detects whether \p loopOp is part of a `target teams/parallel loop` nest.
+  static GenericLoopCombinedInfo
+  findGenericLoopCombineInfo(mlir::omp::LoopOp loopOp) {
+    mlir::Operation *parentOp = loopOp->getParentOp();
+    GenericLoopCombinedInfo result = GenericLoopCombinedInfo::None;
+
+    if (auto teamsOp = mlir::dyn_cast_if_present<mlir::omp::TeamsOp>(parentOp))
+      if (mlir::isa<mlir::omp::TargetOp>(teamsOp->getParentOp()))
+        result = GenericLoopCombinedInfo::TargetTeamsLoop;
+
+    if (auto parallelOp =
+            mlir::dyn_cast_if_present<mlir::omp::ParallelOp>(parentOp))
+      if (mlir::isa<mlir::omp::TargetOp>(parallelOp->getParentOp()))
+        result = GenericLoopCombinedInfo::TargetParallelLoop;
+
+    return result;
+  }
+
+  /// Builds a composite `parallel`/`distribute`/`wsloop` nest for \p loopOp.
+  void rewriteToDistributeParallelDo(
+      mlir::omp::LoopOp loopOp,
+      mlir::ConversionPatternRewriter &rewriter) const {
+    mlir::omp::ParallelOperands parallelClauseOps;
+    parallelClauseOps.privateVars = loopOp.getPrivateVars();
+    if (auto privateSyms = loopOp.getPrivateSyms())
+      parallelClauseOps.privateSyms.assign(privateSyms->begin(),
+                                           privateSyms->end());
+
+    Fortran::openmp::common::EntryBlockArgs parallelArgs;
+    parallelArgs.priv.vars = parallelClauseOps.privateVars;
+
+    auto parallelOp = rewriter.create<mlir::omp::ParallelOp>(loopOp.getLoc(),
+                                                             parallelClauseOps);
+    mlir::Block *parallelBlock =
+        genEntryBlock(rewriter, parallelArgs, parallelOp.getRegion());
+    parallelOp.setComposite(true);
+    rewriter.setInsertionPoint(
+        rewriter.create<mlir::omp::TerminatorOp>(loopOp.getLoc()));
+
+    mlir::omp::DistributeOperands distributeClauseOps;
+    auto distributeOp = rewriter.create<mlir::omp::DistributeOp>(
+        loopOp.getLoc(), distributeClauseOps);
+    distributeOp.setComposite(true);
+    rewriter.createBlock(&distributeOp.getRegion());
+
+    mlir::omp::WsloopOperands wsloopClauseOps;
+    auto wsloopOp =
+        rewriter.create<mlir::omp::WsloopOp>(loopOp.getLoc(), wsloopClauseOps);
+    wsloopOp.setComposite(true);
+    rewriter.createBlock(&wsloopOp.getRegion());
+
+    mlir::IRMapping mapper;
+    mlir::Block &loopBlock = *loopOp.getRegion().begin();
+    for (auto [loopOpArg, parallelOpArg] : llvm::zip_equal(
+             loopBlock.getArguments(), parallelBlock->getArguments()))
+      mapper.map(loopOpArg, parallelOpArg);
+    rewriter.clone(*loopOp.begin(), mapper);
+  }
+};
+
+/// Function pass that rewrites every `omp.loop` op supported by
+/// GenericLoopConversionPattern; all other ops are left as they are.
+class GenericLoopConversionPass
+    : public flangomp::impl::GenericLoopConversionPassBase<
+          GenericLoopConversionPass> {
+public:
+  GenericLoopConversionPass() = default;
+
+  void runOnOperation() override {
+    mlir::func::FuncOp funcOp = getOperation();
+    if (funcOp.isDeclaration())
+      return;
+
+    mlir::MLIRContext &ctx = getContext();
+
+    // Only supported `omp.loop` ops are illegal; every other op is legal.
+    mlir::ConversionTarget target(ctx);
+    target.markUnknownOpDynamicallyLegal(
+        [](mlir::Operation *) { return true; });
+    target.addDynamicallyLegalOp<mlir::omp::LoopOp>([](mlir::omp::LoopOp op) {
+      return !GenericLoopConversionPattern::isLoopConversionSupported(op);
+    });
+
+    mlir::RewritePatternSet patterns(&ctx);
+    patterns.add<GenericLoopConversionPattern>(&ctx);
+
+    if (mlir::succeeded(
+            mlir::applyFullConversion(funcOp, target, std::move(patterns))))
+      return;
+    mlir::emitError(funcOp.getLoc(), "error in converting `omp.loop` op");
+    signalPassFailure();
+  }
+};
+} // namespace
diff --git a/flang/lib/Optimizer/Passes/Pipelines.cpp b/flang/lib/Optimizer/Passes/Pipelines.cpp
index 31af3531641dda..1568165bcd64a4 100644
--- a/flang/lib/Optimizer/Passes/Pipelines.cpp
+++ b/flang/lib/Optimizer/Passes/Pipelines.cpp
@@ -247,6 +247,7 @@ void createOpenMPFIRPassPipeline(mlir::PassManager &pm, bool isTargetDevice) {
pm.addPass(flangomp::createMapInfoFinalizationPass());
pm.addPass(flangomp::createMapsForPrivatizedSymbolsPass());
pm.addPass(flangomp::createMarkDeclareTargetPass());
+ pm.addPass(flangomp::createGenericLoopConversionPass());
if (isTargetDevice)
pm.addPass(flangomp::createFunctionFilteringPass());
}
diff --git a/flang/test/Lower/OpenMP/generic-loop-rewriting.f90 b/flang/test/Lower/OpenMP/generic-loop-rewriting.f90
new file mode 100644
index 00000000000000..fa26425356dd90
--- /dev/null
+++ b/flang/test/Lower/OpenMP/generic-loop-rewriting.f90
@@ -0,0 +1,41 @@
+!RUN: %flang_fc1 -emit-hlfir -fopenmp %s -o - | FileCheck %s
+
+subroutine target_teams_loop
+ implicit none
+ integer :: x, i
+
+ !$omp target teams loop
+ do i = 0, 10
+ x = x + i
+ end do
+end subroutine target_teams_loop
+
+!CHECK-LABEL: func.func @_QPtarget_teams_loop
+!CHECK: omp.target map_entries(
+!CHECK-SAME: %{{.*}} -> %[[I_ARG:[^[:space:]]+]],
+!CHECK-SAME: %{{.*}} -> %[[X_ARG:[^[:space:]]+]] : {{.*}}) {
+
+!CHECK: %[[I_DECL:.*]]:2 = hlfir.declare %[[I_ARG]]
+!CHECK: %[[X_DECL:.*]]:2 = hlfir.declare %[[X_ARG]]
+
+!CHECK: omp.teams {
+
+!CHECK: %[[LB:.*]] = arith.constant 0 : i32
+!CHECK: %[[UB:.*]] = arith.constant 10 : i32
+!CHECK: %[[STEP:.*]] = arith.constant 1 : i32
+
+!CHECK: omp.parallel private(@{{.*}} %[[I_DECL]]#0
+!CHECK-SAME: -> %[[I_PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
+!CHECK: omp.distribute {
+!CHECK: omp.wsloop {
+
+!CHECK: omp.loop_nest (%{{.*}}) : i32 =
+!CHECK-SAME: (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+!CHECK: %[[I_PRIV_DECL:.*]]:2 = hlfir.declare %[[I_PRIV_ARG]]
+!CHECK: fir.store %{{.*}} to %[[I_PRIV_DECL]]#1 : !fir.ref<i32>
+!CHECK: }
+!CHECK: }
+!CHECK: }
+!CHECK: }
+!CHECK: }
+!CHECK: }
diff --git a/flang/test/Transforms/generic-loop-rewriting.mlir b/flang/test/Transforms/generic-loop-rewriting.mlir
new file mode 100644
index 00000000000000..a18ea9853602ac
--- /dev/null
+++ b/flang/test/Transforms/generic-loop-rewriting.mlir
@@ -0,0 +1,57 @@
+// RUN: fir-opt --omp-generic-loop-conversion %s | FileCheck %s
+
+omp.private {type = private} @_QFtarget_teams_loopEi_private_ref_i32 : !fir.ref<i32> alloc {
+^bb0(%arg0: !fir.ref<i32>):
+ omp.yield(%arg0 : !fir.ref<i32>)
+}
+
+func.func @_QPtarget_teams_loop() {
+ %i = fir.alloca i32
+ %i_map = omp.map.info var_ptr(%i : !fir.ref<i32>, i32) map_clauses(implicit, exit_release_or_enter_alloc) capture(ByCopy) -> !fir.ref<i32> {name = "i"}
+ omp.target map_entries(%i_map -> %arg0 : !fir.ref<i32>) {
+ omp.teams {
+ %c0 = arith.constant 0 : i32
+ %c10 = arith.constant 10 : i32
+ %c1 = arith.constant 1 : i32
+ omp.loop private(@_QFtarget_teams_loopEi_private_ref_i32 %arg0 -> %arg2 : !fir.ref<i32>) {
+ omp.loop_nest (%arg3) : i32 = (%c0) to (%c10) inclusive step (%c1) {
+ fir.store %arg3 to %arg2 : !fir.ref<i32>
+ omp.yield
+ }
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @_QPtarget_teams_loop
+// CHECK: omp.target map_entries(
+// CHECK-SAME: %{{.*}} -> %[[I_ARG:[^[:space:]]+]] : {{.*}}) {
+//
+// CHECK: omp.teams {
+//
+// TODO we probably need to move the `loop_nest` bounds ops from the `teams`
+// region to the `parallel` region to avoid making these values `shared`. We can
+// find the backward slices of these bounds that are within the `teams` region
+// and move these slices to the `parallel` op.
+
+// CHECK: %[[LB:.*]] = arith.constant 0 : i32
+// CHECK: %[[UB:.*]] = arith.constant 10 : i32
+// CHECK: %[[STEP:.*]] = arith.constant 1 : i32
+//
+// CHECK: omp.parallel private(@{{.*}} %[[I_ARG]]
+// CHECK-SAME: -> %[[I_PRIV_ARG:[^[:space:]]+]] : !fir.ref<i32>) {
+// CHECK: omp.distribute {
+// CHECK: omp.wsloop {
+//
+// CHECK: omp.loop_nest (%{{.*}}) : i32 =
+// CHECK-SAME: (%[[LB]]) to (%[[UB]]) inclusive step (%[[STEP]]) {
+// CHECK: fir.store %{{.*}} to %[[I_PRIV_ARG]] : !fir.ref<i32>
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
+// CHECK: }
More information about the flang-commits
mailing list