[llvm-branch-commits] [flang] [flang] Lower omp.workshare to other omp constructs (PR #101446)
Ivan R. Ivanov via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Wed Jul 31 19:23:59 PDT 2024
https://github.com/ivanradanov created https://github.com/llvm/llvm-project/pull/101446
4/4
There are two points which need some discussion in this PR:
1. We need to make a value computed in an omp.single accessible to all threads of the omp.parallel region. This is achieved by allocating temporary memory outside the omp.parallel, storing the value to it in the omp.single, and then reloading it from all threads. However, from my reading of the standard, I don't think we are guaranteed that the workshare is lexically nested in the omp.parallel: there could be an omp.parallel { func.call @contains_workshare }, in which case we would not be able to access the omp.parallel. So I think adding runtime support for yielding a value from an omp.single could be the fix for this.
2. For the temporary allocations above, not all types are supported by fir.alloca, so I need to use llvm.alloca and unrealized_conversion_cast to allocate a temporary for a fir.ref type. This too can be fixed by introducing yielding from omp.single.
>From 003568d028b9d7f0323f31f8717527fc52c93c6f Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:44:31 +0900
Subject: [PATCH] [flang] Lower omp.workshare to other omp constructs
---
flang/include/flang/Optimizer/CMakeLists.txt | 1 +
.../flang/Optimizer/OpenMP/CMakeLists.txt | 4 +
flang/include/flang/Optimizer/OpenMP/Passes.h | 30 ++
.../include/flang/Optimizer/OpenMP/Passes.td | 18 ++
flang/include/flang/Tools/CLOptions.inc | 2 +
flang/lib/Frontend/CMakeLists.txt | 1 +
flang/lib/Optimizer/Builder/HLFIRTools.cpp | 21 +-
flang/lib/Optimizer/CMakeLists.txt | 1 +
.../HLFIR/Transforms/BufferizeHLFIR.cpp | 6 +-
flang/lib/Optimizer/OpenMP/CMakeLists.txt | 26 ++
flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 260 ++++++++++++++++++
flang/test/HLFIR/bufferize-workshare.fir | 58 ++++
.../Transforms/OpenMP/lower-workshare.mlir | 81 ++++++
flang/tools/bbc/CMakeLists.txt | 1 +
flang/tools/fir-opt/CMakeLists.txt | 1 +
flang/tools/fir-opt/fir-opt.cpp | 2 +
flang/tools/tco/CMakeLists.txt | 1 +
17 files changed, 505 insertions(+), 9 deletions(-)
create mode 100644 flang/include/flang/Optimizer/OpenMP/CMakeLists.txt
create mode 100644 flang/include/flang/Optimizer/OpenMP/Passes.h
create mode 100644 flang/include/flang/Optimizer/OpenMP/Passes.td
create mode 100644 flang/lib/Optimizer/OpenMP/CMakeLists.txt
create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
create mode 100644 flang/test/HLFIR/bufferize-workshare.fir
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare.mlir
diff --git a/flang/include/flang/Optimizer/CMakeLists.txt b/flang/include/flang/Optimizer/CMakeLists.txt
index 89e43a9ee8d62..3336ac935e101 100644
--- a/flang/include/flang/Optimizer/CMakeLists.txt
+++ b/flang/include/flang/Optimizer/CMakeLists.txt
@@ -2,3 +2,4 @@ add_subdirectory(CodeGen)
add_subdirectory(Dialect)
add_subdirectory(HLFIR)
add_subdirectory(Transforms)
+add_subdirectory(OpenMP)
diff --git a/flang/include/flang/Optimizer/OpenMP/CMakeLists.txt b/flang/include/flang/Optimizer/OpenMP/CMakeLists.txt
new file mode 100644
index 0000000000000..d59573f0f7fd9
--- /dev/null
+++ b/flang/include/flang/Optimizer/OpenMP/CMakeLists.txt
@@ -0,0 +1,4 @@
+# Generate the pass declarations and registration hooks for the Flang OpenMP
+# passes from Passes.td; `-name FlangOpenMP` determines the name of the
+# generated registration function (registerFlangOpenMPPasses).
+set(LLVM_TARGET_DEFINITIONS Passes.td)
+mlir_tablegen(Passes.h.inc -gen-pass-decls -name FlangOpenMP)

+add_public_tablegen_target(FlangOpenMPPassesIncGen)
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
new file mode 100644
index 0000000000000..95a05b3005073
--- /dev/null
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -0,0 +1,30 @@
+//===- Passes.h - OpenMP pass entry points ----------------------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This header declares OpenMP pass entry points.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_OPTIMIZER_OPENMP_PASSES_H
+#define FORTRAN_OPTIMIZER_OPENMP_PASSES_H
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Pass/PassRegistry.h"
+#include <memory>
+
+namespace flangomp {
+#define GEN_PASS_DECL
+#define GEN_PASS_REGISTRATION
+#include "flang/Optimizer/OpenMP/Passes.h.inc"

+/// Returns true when `op` is directly nested in an omp.workshare that has an
+/// enclosing omp.parallel. Bufferization uses this to decide whether to emit
+/// an omp.wsloop for `op` that the LowerWorkshare pass later distributes.
+bool shouldUseWorkshareLowering(mlir::Operation *op);

+} // namespace flangomp
+
+#endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
new file mode 100644
index 0000000000000..6f636ec1df616
--- /dev/null
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -0,0 +1,18 @@
+//===-- Passes.td - OpenMP pass definition file ------------*- tablegen -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef FORTRAN_DIALECT_OPENMP_PASSES
+#define FORTRAN_DIALECT_OPENMP_PASSES
+
+include "mlir/Pass/PassBase.td"
+
+// Splits every omp.workshare into omp.single regions and the omp.wsloop
+// operations nested directly in it, inserting barriers between them.
+def LowerWorkshare : Pass<"lower-workshare"> {
+ let summary = "Lower workshare construct";
+}
+
+#endif //FORTRAN_DIALECT_OPENMP_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 7df5044949463..594369fc2ffe5 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -17,6 +17,7 @@
#include "mlir/Transforms/Passes.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
#include "flang/Optimizer/Transforms/Passes.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Support/CommandLine.h"
@@ -344,6 +345,7 @@ inline void createHLFIRToFIRPassPipeline(
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
pm.addPass(hlfir::createBufferizeHLFIR());
pm.addPass(hlfir::createConvertHLFIRtoFIR());
+ pm.addPass(flangomp::createLowerWorkshare());
}
/// Create a pass pipeline for handling certain OpenMP transformations needed
diff --git a/flang/lib/Frontend/CMakeLists.txt b/flang/lib/Frontend/CMakeLists.txt
index c20b9096aff49..ecdcc73d61ec1 100644
--- a/flang/lib/Frontend/CMakeLists.txt
+++ b/flang/lib/Frontend/CMakeLists.txt
@@ -38,6 +38,7 @@ add_flang_library(flangFrontend
FIRTransforms
HLFIRDialect
HLFIRTransforms
+ FlangOpenMPTransforms
MLIRTransforms
MLIRBuiltinToLLVMIRTranslation
MLIRLLVMToLLVMIRTranslation
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 5406e5076d55c..cd07cb741eb4b 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -865,12 +865,11 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
// Build loop nest from column to row.
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
mlir::Type indexType = builder.getIndexType();
- unsigned dim = extents.size() - 1;
-
if (emitWsLoop) {
- auto wsloop = builder.create<mlir::omp::WsloopOp>(loc, mlir::ArrayRef<mlir::NamedAttribute>());
+ auto wsloop = builder.create<mlir::omp::WsloopOp>(
+ loc, mlir::ArrayRef<mlir::NamedAttribute>());
loopNest.outerOp = wsloop;
- builder.createBlock(wsloop.getBody());
+ builder.createBlock(&wsloop.getRegion());
mlir::omp::LoopNestOperands lnops;
lnops.loopInclusive = builder.getUnitAttr();
for (auto extent : llvm::reverse(extents)) {
@@ -879,13 +878,21 @@ hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
lnops.loopSteps.push_back(one);
}
auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
+ builder.create<mlir::omp::TerminatorOp>(loc);
+ mlir::Block *block = builder.createBlock(&lnOp.getRegion());
+ for (auto extent : llvm::reverse(extents))
+ block->addArgument(extent.getType(), extent.getLoc());
+ loopNest.body = block;
builder.create<mlir::omp::YieldOp>(loc);
- builder.createBlock(&lnOp.getRegion().front());
- builder.create<mlir::omp::YieldOp>(loc);
+ for (unsigned dim = 0; dim < extents.size(); dim++)
+ loopNest.oneBasedIndices[extents.size() - dim - 1] =
+ lnOp.getRegion().front().getArgument(dim);
} else {
+ unsigned dim = extents.size() - 1;
for (auto extent : llvm::reverse(extents)) {
auto ub = builder.createConvert(loc, indexType, extent);
- auto doLoop = builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
+ auto doLoop =
+ builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
loopNest.body = doLoop.getBody();
builder.setInsertionPointToStart(loopNest.body);
// Reverse the indices so they are in column-major order.
diff --git a/flang/lib/Optimizer/CMakeLists.txt b/flang/lib/Optimizer/CMakeLists.txt
index 4a602162ed2b7..dd153ac33c0fb 100644
--- a/flang/lib/Optimizer/CMakeLists.txt
+++ b/flang/lib/Optimizer/CMakeLists.txt
@@ -5,3 +5,4 @@ add_subdirectory(HLFIR)
add_subdirectory(Support)
add_subdirectory(Transforms)
add_subdirectory(Analysis)
+add_subdirectory(OpenMP)
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index b608677c52631..1848dbe2c7a2c 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -26,12 +26,13 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/ADT/TypeSwitch.h"
namespace hlfir {
@@ -792,7 +793,8 @@ struct ElementalOpConversion
// Generate a loop nest looping around the fir.elemental shape and clone
// fir.elemental region inside the inner loop.
hlfir::LoopNest loopNest =
- hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
+ hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
+ flangomp::shouldUseWorkshareLowering(elemental));
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
new file mode 100644
index 0000000000000..74419327d76d0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -0,0 +1,26 @@
+# FlangOpenMPTransforms: Flang's OpenMP-specific MLIR passes (currently the
+# omp.workshare lowering implemented in LowerWorkshare.cpp).
+get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)

+add_flang_library(FlangOpenMPTransforms
+ LowerWorkshare.cpp

+ DEPENDS
+ FIRDialect
+ # Generated pass declarations from include/flang/Optimizer/OpenMP/Passes.td.
+ FlangOpenMPPassesIncGen
+ ${dialect_libs}

+ LINK_LIBS
+ FIRAnalysis
+ FIRDialect
+ FIRBuilder
+ FIRDialectSupport
+ FIRSupport
+ FIRTransforms
+ HLFIRDialect
+ MLIRIR
+ ${dialect_libs}

+ LINK_COMPONENTS
+ AsmParser
+ AsmPrinter
+ Remarks
+)
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
new file mode 100644
index 0000000000000..ddcfc1f345d50
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,260 @@
+//===- LowerWorkshare.cpp - Lower the omp.workshare construct ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+// Lower omp workshare construct.
+//===----------------------------------------------------------------------===//
+
+#include "flang/Optimizer/Dialect/FIROps.h"
+#include "flang/Optimizer/Dialect/FIRType.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPClauseOperands.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/BuiltinOps.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Support/LLVM.h"
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/ADT/iterator_range.h"
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+/// Returns true if `op` is directly nested in an omp.workshare that itself has
+/// an enclosing omp.parallel. Only then can LowerWorkshare split the construct
+/// into omp.single / omp.wsloop regions, so only then should callers (e.g.
+/// HLFIR bufferization) emit an omp.wsloop for `op`.
+bool shouldUseWorkshareLowering(Operation *op) {
+  // A detached operation has no parent; dyn_cast would assert on null.
+  auto workshare = dyn_cast_or_null<omp::WorkshareOp>(op->getParentOp());
+  if (!workshare)
+    return false;
+  return static_cast<bool>(workshare->getParentOfType<omp::ParallelOp>());
+}
+} // namespace flangomp
+
+namespace {
+
+/// Half-open range [begin, end) of consecutive operations of the workshare
+/// block that are grouped together into one omp.single.
+struct SingleRegion {
+  Block::iterator begin;
+  Block::iterator end;
+};
+
+/// Whether a temporary of type `ty` can be created with fir.alloca. Values of
+/// fir.ref type are excluded (per the PR notes, fir.alloca does not accept
+/// them) and are routed through an LLVM pointer temporary instead.
+static bool isSupportedByFirAlloca(Type ty) {
+  const bool isFirReference = isa<fir::ReferenceType>(ty);
+  return !isFirReference;
+}
+
+/// Returns true if `op` may be executed redundantly by every thread of the
+/// team, i.e. it can be cloned outside of the omp.single as well.
+static bool isSafeToParallelize(Operation *op) {
+  // fir.declare is explicitly allowed regardless of what effects it reports.
+  if (isa<fir::DeclareOp>(op))
+    return true;
+
+  // Ops that do not describe their memory effects are conservatively unsafe.
+  auto effectInterface = dyn_cast<MemoryEffectOpInterface>(op);
+  if (!effectInterface)
+    return false;
+
+  // Only ops that report no memory effects at all are duplicated.
+  llvm::SmallVector<MemoryEffects::EffectInstance> effectList;
+  effectInterface.getEffects(effectList);
+  return effectList.empty();
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+///   %a = fir.allocmem
+///   omp.wsloop {}
+///   fir.call Assign %b %a
+///   fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// omp.single {
+///   %a = fir.allocmem
+///   fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+///   fir.call Assign %b %a_reloaded
+///   fir.freemem %a_reloaded
+/// }
+///
+/// Operations that are safe to parallelize (see isSafeToParallelize) are
+/// additionally cloned outside of the omp.single so every thread has them.
+/// Results of the remaining operations that are needed outside of their
+/// omp.single (used after it, or used by a cloned safe operation) are passed
+/// through temporary memory allocated before the closest omp.parallel.
+/// Every region except the last is followed by a barrier (the implicit one of
+/// omp.single, or an explicit omp.barrier after an omp.wsloop), and a final
+/// omp.barrier is added unless the workshare is `nowait`.
+///
+/// TODO currently we need to be able to access the encompassing omp.parallel so
+/// that we can allocate temporaries accessible by all threads outside of it.
+/// In case we do not find it, we fall back to converting the omp.workshare to
+/// omp.single.
+/// To better handle this we should probably enable yielding values out of an
+/// omp.single which will be supported by the omp runtime.
+void lowerWorkshare(mlir::omp::WorkshareOp wsOp) {
+  assert(wsOp.getRegion().getBlocks().size() == 1);
+
+  Location loc = wsOp->getLoc();
+
+  omp::ParallelOp parallelOp = wsOp->getParentOfType<omp::ParallelOp>();
+  if (!parallelOp) {
+    // There is nowhere to put team-shared temporaries, so run the whole body
+    // on a single thread. Simply inlining the body would make every thread of
+    // an (orphaned) team execute it. The body's omp.terminator is also a valid
+    // omp.single terminator, so the region can be moved over unchanged.
+    wsOp.emitWarning("cannot handle workshare, converting to single");
+    OpBuilder fallbackBuilder(wsOp);
+    omp::SingleOperands fallbackOperands;
+    if (wsOp.getNowait())
+      fallbackOperands.nowait = fallbackBuilder.getUnitAttr();
+    auto fallbackSingle =
+        fallbackBuilder.create<omp::SingleOp>(loc, fallbackOperands);
+    fallbackSingle.getRegion().takeBody(wsOp.getRegion());
+    wsOp->erase();
+    return;
+  }
+
+  OpBuilder allocBuilder(parallelOp);
+  OpBuilder rootBuilder(wsOp);
+  IRMapping rootMapping;
+
+  // Makes `v` (computed inside an omp.single) available to all threads: it is
+  // stored to a temporary allocated before the omp.parallel from inside the
+  // omp.single and loaded back right after it. Each value is handled once.
+  auto mapReloadedValue = [&](Value v, OpBuilder &singleBuilder,
+                              const IRMapping &singleMapping) {
+    if (rootMapping.contains(v))
+      return;
+    Type ty = v.getType();
+    Value alloc, reloaded;
+    if (isSupportedByFirAlloca(ty)) {
+      alloc = allocBuilder.create<fir::AllocaOp>(loc, ty);
+      singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+      reloaded = rootBuilder.create<fir::LoadOp>(loc, ty, alloc);
+    } else {
+      // fir.alloca cannot hold this type, so go through an opaque LLVM pointer
+      // temporary and cast back to the original type after reloading.
+      Type llvmPtrTy = LLVM::LLVMPointerType::get(allocBuilder.getContext());
+      auto one = allocBuilder.create<LLVM::ConstantOp>(
+          loc, allocBuilder.getI32Type(), 1);
+      alloc =
+          allocBuilder.create<LLVM::AllocaOp>(loc, llvmPtrTy, llvmPtrTy, one);
+      Value toStore = singleBuilder
+                          .create<UnrealizedConversionCastOp>(
+                              loc, llvmPtrTy, singleMapping.lookup(v))
+                          .getResult(0);
+      singleBuilder.create<LLVM::StoreOp>(loc, toStore, alloc);
+      reloaded = rootBuilder.create<LLVM::LoadOp>(loc, llvmPtrTy, alloc);
+      reloaded =
+          rootBuilder.create<UnrealizedConversionCastOp>(loc, ty, reloaded)
+              .getResult(0);
+    }
+    rootMapping.map(v, reloaded);
+  };
+
+  // Clones the operations of `sr` into the omp.single being built by
+  // `singleBuilder`; safe operations are also cloned outside of it.
+  auto moveToSingle = [&](SingleRegion sr, OpBuilder &singleBuilder) {
+    IRMapping singleMapping = rootMapping;
+
+    for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+      singleBuilder.clone(op, singleMapping);
+      if (isSafeToParallelize(&op)) {
+        rootBuilder.clone(op, rootMapping);
+        continue;
+      }
+      // Results of this op are needed outside of the omp.single when they are
+      // used after the region `sr`, or by an op that is itself cloned outside
+      // because it is safe to parallelize (e.g. fir.declare of a
+      // fir.allocmem). Without a reload, that clone would keep referring to
+      // the original value, which is erased together with `wsOp`.
+      for (auto res : op.getResults()) {
+        for (auto &use : res.getUses()) {
+          Operation *user = use.getOwner();
+          while (user->getParentOp() != wsOp)
+            user = user->getParentOp();
+          if (!user->isBeforeInBlock(&*sr.end) || isSafeToParallelize(user))
+            mapReloadedValue(use.get(), singleBuilder, singleMapping);
+        }
+      }
+    }
+    singleBuilder.create<omp::TerminatorOp>(loc);
+  };
+
+  Block *wsBlock = &wsOp.getRegion().front();
+  assert(wsBlock->getTerminator()->getNumOperands() == 0);
+  Operation *terminator = wsBlock->getTerminator();
+
+  // Partition the body into the omp.wsloop operations and the maximal runs of
+  // other operations between them (each run becomes one omp.single).
+  SmallVector<std::variant<SingleRegion, omp::WsloopOp>> regions;
+
+  auto it = wsBlock->begin();
+  auto getSingleRegion = [&]() {
+    if (&*it == terminator)
+      return false;
+    if (auto pop = dyn_cast<omp::WsloopOp>(&*it)) {
+      regions.push_back(pop);
+      it++;
+      return true;
+    }
+    SingleRegion sr;
+    sr.begin = it;
+    while (&*it != terminator && !isa<omp::WsloopOp>(&*it))
+      it++;
+    sr.end = it;
+    assert(sr.begin != sr.end);
+    regions.push_back(sr);
+    return true;
+  };
+  while (getSingleRegion())
+    ;
+
+  for (auto [i, loopOrSingle] : llvm::enumerate(regions)) {
+    bool isLast = i + 1 == regions.size();
+    if (std::holds_alternative<SingleRegion>(loopOrSingle)) {
+      omp::SingleOperands singleOperands;
+      // The trailing barrier (if any) is emitted after this loop instead.
+      if (isLast)
+        singleOperands.nowait = rootBuilder.getUnitAttr();
+      omp::SingleOp singleOp =
+          rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+      OpBuilder singleBuilder(singleOp);
+      singleBuilder.createBlock(&singleOp.getRegion());
+      moveToSingle(std::get<SingleRegion>(loopOrSingle), singleBuilder);
+    } else {
+      rootBuilder.clone(*std::get<omp::WsloopOp>(loopOrSingle), rootMapping);
+      if (!isLast)
+        rootBuilder.create<omp::BarrierOp>(loc);
+    }
+  }
+
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
+
+  wsOp->erase();
+}
+
+/// Pass driver: lowers every omp.workshare nested in the pass's operation, then
+/// folds/cleans up each isolated-from-above op that contained one.
+class LowerWorksharePass
+    : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
+public:
+  void runOnOperation() override {
+    SmallPtrSet<Operation *, 8> parents;
+    // Post-order walk (the default), so erasing the visited omp.workshare
+    // inside the callback is safe.
+    getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
+      if (Operation *isolatedParent =
+              wsOp->getParentWithTrait<OpTrait::IsIsolatedFromAbove>())
+        parents.insert(isolatedParent);
+
+      lowerWorkshare(wsOp);
+    });
+
+    // Run only folding and dead-code elimination (the pattern set is empty)
+    // to clean up after the lowering, e.g. unused duplicated operations.
+    for (Operation *isolatedParent : parents) {
+      RewritePatternSet patterns(&getContext());
+      GreedyRewriteConfig config;
+      // Prevent the pattern driver from merging blocks.
+      config.enableRegionSimplification =
+          mlir::GreedySimplifyRegionLevel::Disabled;
+      if (failed(applyPatternsAndFoldGreedily(isolatedParent,
+                                              std::move(patterns), config))) {
+        emitError(isolatedParent->getLoc(), "error in lower workshare");
+        signalPassFailure();
+      }
+    }
+  }
+};
+} // namespace
diff --git a/flang/test/HLFIR/bufferize-workshare.fir b/flang/test/HLFIR/bufferize-workshare.fir
new file mode 100644
index 0000000000000..86a2f031478dd
--- /dev/null
+++ b/flang/test/HLFIR/bufferize-workshare.fir
@@ -0,0 +1,58 @@
+// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s
+
+// CHECK-LABEL: func.func @simple(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK: omp.parallel {
+// CHECK: omp.workshare {
+// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: %[[VAL_7:.*]] = arith.constant true
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
+// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
+// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_11]], %[[VAL_2]] : i32
+// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK: hlfir.assign %[[VAL_6]]#0 to %[[VAL_4]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK: fir.freemem %[[VAL_6]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @simple(%arg: !fir.ref<!fir.array<42xi32>>) {
+ omp.parallel {
+ omp.workshare {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ %ref = hlfir.designate %array#0 (%i) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ %val = fir.load %ref : !fir.ref<i32>
+ %sub = arith.subi %val, %c1_i32 : i32
+ hlfir.yield_element %sub : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
new file mode 100644
index 0000000000000..a8d36443f08bd
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,81 @@
+// RUN: fir-opt --lower-workshare %s | FileCheck %s
+
+module {
+// CHECK-LABEL: func.func @simple(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK: %[[VAL_1:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_4:.*]] = llvm.mlir.constant(1 : i32) : i32
+// CHECK: %[[VAL_5:.*]] = llvm.alloca %[[VAL_4]] x !llvm.ptr : (i32) -> !llvm.ptr
+// CHECK: %[[VAL_6:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.parallel {
+// CHECK: omp.single {
+// CHECK: %[[VAL_7:.*]] = fir.shape %[[VAL_3]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_7]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_9:.*]] = builtin.unrealized_conversion_cast %[[VAL_8]]#0 : !fir.ref<!fir.array<42xi32>> to !llvm.ptr
+// CHECK: llvm.store %[[VAL_9]], %[[VAL_5]] : !llvm.ptr, !llvm.ptr
+// CHECK: %[[VAL_10:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK: %[[VAL_11:.*]]:2 = hlfir.declare %[[VAL_10]](%[[VAL_7]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: fir.store %[[VAL_11]]#0 to %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_12:.*]] = llvm.load %[[VAL_5]] : !llvm.ptr -> !llvm.ptr
+// CHECK: %[[VAL_13:.*]] = builtin.unrealized_conversion_cast %[[VAL_12]] : !llvm.ptr to !fir.ref<!fir.array<42xi32>>
+// CHECK: %[[VAL_14:.*]] = fir.load %[[VAL_6]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_15:.*]]) : index = (%[[VAL_1]]) to (%[[VAL_3]]) inclusive step (%[[VAL_1]]) {
+// CHECK: %[[VAL_16:.*]] = hlfir.designate %[[VAL_13]] (%[[VAL_15]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_17:.*]] = fir.load %[[VAL_16]] : !fir.ref<i32>
+// CHECK: %[[VAL_18:.*]] = arith.subi %[[VAL_17]], %[[VAL_2]] : i32
+// CHECK: %[[VAL_19:.*]] = hlfir.designate %[[VAL_14]] (%[[VAL_15]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_18]] to %[[VAL_19]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.barrier
+// CHECK: omp.single nowait {
+// CHECK: hlfir.assign %[[VAL_14]] to %[[VAL_13]] : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK: fir.freemem %[[VAL_14]] : !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.barrier
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+ func.func @simple(%arg0: !fir.ref<!fir.array<42xi32>>) {
+ omp.parallel {
+ omp.workshare {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+ %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+ %true = arith.constant true
+ %c1 = arith.constant 1 : index
+ omp.wsloop {
+ omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+ %7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ %8 = fir.load %7 : !fir.ref<i32>
+ %9 = arith.subi %8, %c1_i32 : i32
+ %10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+ omp.yield
+ }
+ omp.terminator
+ }
+ %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+ fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+ }
+}
diff --git a/flang/tools/bbc/CMakeLists.txt b/flang/tools/bbc/CMakeLists.txt
index 9410fd0056600..69316d4dc61de 100644
--- a/flang/tools/bbc/CMakeLists.txt
+++ b/flang/tools/bbc/CMakeLists.txt
@@ -25,6 +25,7 @@ FIRTransforms
FIRBuilder
HLFIRDialect
HLFIRTransforms
+FlangOpenMPTransforms
${dialect_libs}
${extension_libs}
MLIRAffineToStandard
diff --git a/flang/tools/fir-opt/CMakeLists.txt b/flang/tools/fir-opt/CMakeLists.txt
index 43679a9d53578..4c6dbf7d9c8c3 100644
--- a/flang/tools/fir-opt/CMakeLists.txt
+++ b/flang/tools/fir-opt/CMakeLists.txt
@@ -19,6 +19,7 @@ target_link_libraries(fir-opt PRIVATE
FIRCodeGen
HLFIRDialect
HLFIRTransforms
+ FlangOpenMPTransforms
FIRAnalysis
${test_libs}
${dialect_libs}
diff --git a/flang/tools/fir-opt/fir-opt.cpp b/flang/tools/fir-opt/fir-opt.cpp
index 1846c1b317848..f75fba27c68f0 100644
--- a/flang/tools/fir-opt/fir-opt.cpp
+++ b/flang/tools/fir-opt/fir-opt.cpp
@@ -14,6 +14,7 @@
#include "mlir/Tools/mlir-opt/MlirOptMain.h"
#include "flang/Optimizer/CodeGen/CodeGen.h"
#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
#include "flang/Optimizer/Support/InitFIR.h"
#include "flang/Optimizer/Transforms/Passes.h"
@@ -34,6 +35,7 @@ int main(int argc, char **argv) {
fir::registerOptCodeGenPasses();
fir::registerOptTransformPasses();
hlfir::registerHLFIRPasses();
+ flangomp::registerFlangOpenMPPasses();
#ifdef FLANG_INCLUDE_TESTS
fir::test::registerTestFIRAliasAnalysisPass();
mlir::registerSideEffectTestPasses();
diff --git a/flang/tools/tco/CMakeLists.txt b/flang/tools/tco/CMakeLists.txt
index 808219ac361f2..698a398547c77 100644
--- a/flang/tools/tco/CMakeLists.txt
+++ b/flang/tools/tco/CMakeLists.txt
@@ -17,6 +17,7 @@ target_link_libraries(tco PRIVATE
FIRBuilder
HLFIRDialect
HLFIRTransforms
+ FlangOpenMPTransforms
${dialect_libs}
${extension_libs}
MLIRIR
More information about the llvm-branch-commits
mailing list