[llvm-branch-commits] [flang] [WIP][flang] Introduce HLFIR lowerings to omp.workshare_loop_nest (PR #104748)
Ivan R. Ivanov via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Aug 22 01:59:29 PDT 2024
https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/104748
>From d343f3a86f56864757ccdf889fd6897d9d9507e9 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:09:09 +0900
Subject: [PATCH 1/6] [MLIR][omp] Add omp.workshare op
Add custom omp loop wrapper
Add recursive memory effects trait to workshare
Remove stray include
Remove omp.workshare verifier
Add assembly format for wrapper and add test
Add verification and descriptions
---
.../Dialect/OpenMP/OpenMPClauseOperands.h | 2 +
mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td | 43 ++++++++++++
mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 23 +++++++
mlir/test/Dialect/OpenMP/invalid.mlir | 42 +++++++++++
mlir/test/Dialect/OpenMP/ops.mlir | 69 +++++++++++++++++++
5 files changed, 179 insertions(+)
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
index 38e4d8f245e4fa..896ca9581c3fc8 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPClauseOperands.h
@@ -316,6 +316,8 @@ using TeamsOperands =
detail::Clauses<AllocateClauseOps, IfClauseOps, NumTeamsClauseOps,
PrivateClauseOps, ReductionClauseOps, ThreadLimitClauseOps>;
+using WorkshareOperands = detail::Clauses<NowaitClauseOps>;
+
using WsloopOperands =
detail::Clauses<AllocateClauseOps, LinearClauseOps, NowaitClauseOps,
OrderClauseOps, OrderedClauseOps, PrivateClauseOps,
diff --git a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
index d63fdd88f79104..dc50307de78230 100644
--- a/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
+++ b/mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
@@ -287,6 +287,49 @@ def SingleOp : OpenMP_Op<"single", traits = [
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// 2.8.3 Workshare Construct
+//===----------------------------------------------------------------------===//
+
+def WorkshareOp : OpenMP_Op<"workshare", traits = [
+ RecursiveMemoryEffects,
+ ], clauses = [
+ OpenMP_NowaitClause,
+ ], singleRegion = true> {
+ let summary = "workshare directive";
+ let description = [{
+ The workshare construct divides the execution of the enclosed structured
+ block into separate units of work, and causes the threads of the team to
+ share the work such that each unit is executed only once by one thread, in
+ the context of its implicit task.
+
+ This operation is used for the intermediate representation of the workshare
+ block before the work gets divided between the threads. See the flang
+ LowerWorkshare pass for details.
+ }] # clausesDescription;
+
+ let builders = [
+ OpBuilder<(ins CArg<"const WorkshareOperands &">:$clauses)>
+ ];
+}
+
+def WorkshareLoopWrapperOp : OpenMP_Op<"workshare.loop_wrapper", traits = [
+ DeclareOpInterfaceMethods<LoopWrapperInterface>,
+ RecursiveMemoryEffects, SingleBlock
+ ], singleRegion = true> {
+ let summary = "contains loop nests to be parallelized by workshare";
+ let description = [{
+ This operation wraps a loop nest that is marked for dividing into units of
+ work by an encompassing omp.workshare operation.
+ }];
+
+ let builders = [
+ OpBuilder<(ins), [{ build($_builder, $_state, {}); }]>
+ ];
+ let assemblyFormat = "$region attr-dict";
+ let hasVerifier = 1;
+}
+
//===----------------------------------------------------------------------===//
// Loop Nest
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 4c943ebbe3144f..f4acbd97ca6d1a 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -1689,6 +1689,29 @@ LogicalResult SingleOp::verify() {
getCopyprivateSyms());
}
+//===----------------------------------------------------------------------===//
+// WorkshareOp
+//===----------------------------------------------------------------------===//
+
+void WorkshareOp::build(OpBuilder &builder, OperationState &state,
+ const WorkshareOperands &clauses) {
+ WorkshareOp::build(builder, state, clauses.nowait);
+}
+
+//===----------------------------------------------------------------------===//
+// WorkshareLoopWrapperOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult WorkshareLoopWrapperOp::verify() {
+ if (!isWrapper())
+ return emitOpError() << "must be a loop wrapper";
+ if (getNestedWrapper())
+ return emitError() << "nested wrappers not supported";
+ if (!(*this)->getParentOfType<WorkshareOp>())
+ return emitError() << "must be nested in an omp.workshare";
+ return success();
+}
+
//===----------------------------------------------------------------------===//
// WsloopOp
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/OpenMP/invalid.mlir b/mlir/test/Dialect/OpenMP/invalid.mlir
index c76b07ec94a597..c330f4c8d0f76a 100644
--- a/mlir/test/Dialect/OpenMP/invalid.mlir
+++ b/mlir/test/Dialect/OpenMP/invalid.mlir
@@ -2545,3 +2545,45 @@ func.func @omp_taskloop_invalid_composite(%lb: index, %ub: index, %step: index)
} {omp.composite}
return
}
+
+// -----
+func.func @nested_wrapper(%idx : index) {
+ omp.workshare {
+ // expected-error @below {{nested wrappers not supported}}
+ omp.workshare.loop_wrapper {
+ omp.simd {
+ omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @not_wrapper() {
+ omp.workshare {
+ // expected-error @below {{must be a loop wrapper}}
+ omp.workshare.loop_wrapper {
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+func.func @missing_workshare(%idx : index) {
+ // expected-error @below {{must be nested in an omp.workshare}}
+ omp.workshare.loop_wrapper {
+ omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/mlir/test/Dialect/OpenMP/ops.mlir b/mlir/test/Dialect/OpenMP/ops.mlir
index 9ac97e069addd2..0e1f5ebb1a3739 100644
--- a/mlir/test/Dialect/OpenMP/ops.mlir
+++ b/mlir/test/Dialect/OpenMP/ops.mlir
@@ -2810,3 +2810,72 @@ func.func @omp_target_private(%map1: memref<?xi32>, %map2: memref<?xi32>, %priv_
return
}
+
+// CHECK-LABEL: func @omp_workshare
+func.func @omp_workshare() {
+ // CHECK: omp.workshare {
+ omp.workshare {
+ "test.payload"() : () -> ()
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func @omp_workshare_nowait
+func.func @omp_workshare_nowait() {
+ // CHECK: omp.workshare nowait {
+ omp.workshare nowait {
+ "test.payload"() : () -> ()
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func @omp_workshare_multiple_blocks
+func.func @omp_workshare_multiple_blocks() {
+ // CHECK: omp.workshare {
+ omp.workshare {
+ cf.br ^bb2
+ ^bb2:
+ // CHECK: omp.terminator
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func @omp_workshare.loop_wrapper
+func.func @omp_workshare.loop_wrapper(%idx : index) {
+ // CHECK-NEXT: omp.workshare {
+ omp.workshare {
+ // CHECK-NEXT: omp.workshare.loop_wrapper
+ omp.workshare.loop_wrapper {
+ // CHECK-NEXT: omp.loop_nest
+ omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func @omp_workshare.loop_wrapper_attrs
+func.func @omp_workshare.loop_wrapper_attrs(%idx : index) {
+ // CHECK-NEXT: omp.workshare {
+ omp.workshare {
+ // CHECK-NEXT: omp.workshare.loop_wrapper {
+ omp.workshare.loop_wrapper {
+ // CHECK-NEXT: omp.loop_nest
+ omp.loop_nest (%iv) : index = (%idx) to (%idx) step (%idx) {
+ omp.yield
+ }
+ omp.terminator
+ // CHECK: } {attr_in_dict}
+ } {attr_in_dict}
+ omp.terminator
+ }
+ return
+}
>From d4310a06639c6cd1565aac2d2bbfebffcf9e175d Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:11:47 +0900
Subject: [PATCH 2/6] [flang][omp] Emit omp.workshare in frontend
Fix lower test for workshare
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 30 +++++++++++++++++++++++----
flang/test/Lower/OpenMP/workshare.f90 | 6 +++---
2 files changed, 29 insertions(+), 7 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index d614db8b68ef65..83c90374afa5e3 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1272,6 +1272,15 @@ static void genTaskwaitClauses(lower::AbstractConverter &converter,
loc, llvm::omp::Directive::OMPD_taskwait);
}
+static void genWorkshareClauses(lower::AbstractConverter &converter,
+ semantics::SemanticsContext &semaCtx,
+ lower::StatementContext &stmtCtx,
+ const List<Clause> &clauses, mlir::Location loc,
+ mlir::omp::WorkshareOperands &clauseOps) {
+ ClauseProcessor cp(converter, semaCtx, clauses);
+ cp.processNowait(clauseOps);
+}
+
static void genTeamsClauses(lower::AbstractConverter &converter,
semantics::SemanticsContext &semaCtx,
lower::StatementContext &stmtCtx,
@@ -1897,6 +1906,22 @@ genTaskyieldOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
return converter.getFirOpBuilder().create<mlir::omp::TaskyieldOp>(loc);
}
+static mlir::omp::WorkshareOp
+genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
+ semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
+ mlir::Location loc, const ConstructQueue &queue,
+ ConstructQueue::iterator item) {
+ lower::StatementContext stmtCtx;
+ mlir::omp::WorkshareOperands clauseOps;
+ genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+
+ return genOpWithBody<mlir::omp::WorkshareOp>(
+ OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
+ llvm::omp::Directive::OMPD_workshare)
+ .setClauses(&item->clauses),
+ queue, item, clauseOps);
+}
+
static mlir::omp::TeamsOp
genTeamsOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
@@ -2309,10 +2334,7 @@ static void genOMPDispatch(lower::AbstractConverter &converter,
llvm::omp::getOpenMPDirectiveName(dir) + ")");
// case llvm::omp::Directive::OMPD_workdistribute:
case llvm::omp::Directive::OMPD_workshare:
- // FIXME: Workshare is not a commonly used OpenMP construct, an
- // implementation for this feature will come later. For the codes
- // that use this construct, add a single construct for now.
- genSingleOp(converter, symTable, semaCtx, eval, loc, queue, item);
+ genWorkshareOp(converter, symTable, semaCtx, eval, loc, queue, item);
break;
default:
// Combined and composite constructs should have been split into a sequence
diff --git a/flang/test/Lower/OpenMP/workshare.f90 b/flang/test/Lower/OpenMP/workshare.f90
index 1e11677a15e1f0..8e771952f5b6da 100644
--- a/flang/test/Lower/OpenMP/workshare.f90
+++ b/flang/test/Lower/OpenMP/workshare.f90
@@ -6,7 +6,7 @@ subroutine sb1(arr)
integer :: arr(:)
!CHECK: omp.parallel {
!$omp parallel
-!CHECK: omp.single {
+!CHECK: omp.workshare {
!$omp workshare
arr = 0
!$omp end workshare
@@ -20,7 +20,7 @@ subroutine sb2(arr)
integer :: arr(:)
!CHECK: omp.parallel {
!$omp parallel
-!CHECK: omp.single nowait {
+!CHECK: omp.workshare nowait {
!$omp workshare
arr = 0
!$omp end workshare nowait
@@ -33,7 +33,7 @@ subroutine sb2(arr)
subroutine sb3(arr)
integer :: arr(:)
!CHECK: omp.parallel {
-!CHECK: omp.single {
+!CHECK: omp.workshare {
!$omp parallel workshare
arr = 0
!$omp end parallel workshare
>From fb06794ba6259fc3bcc3b9c73108a03e77a0b42d Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Thu, 22 Aug 2024 17:01:43 +0900
Subject: [PATCH 3/6] Fix function signature
---
flang/lib/Lower/OpenMP/OpenMP.cpp | 10 ++++++----
1 file changed, 6 insertions(+), 4 deletions(-)
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 83c90374afa5e3..086dd7aaeaab88 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1908,12 +1908,14 @@ genTaskyieldOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
static mlir::omp::WorkshareOp
genWorkshareOp(lower::AbstractConverter &converter, lower::SymMap &symTable,
- semantics::SemanticsContext &semaCtx, lower::pft::Evaluation &eval,
- mlir::Location loc, const ConstructQueue &queue,
- ConstructQueue::iterator item) {
+ semantics::SemanticsContext &semaCtx,
+ lower::pft::Evaluation &eval, mlir::Location loc,
+ const ConstructQueue &queue,
+ ConstructQueue::const_iterator item) {
lower::StatementContext stmtCtx;
mlir::omp::WorkshareOperands clauseOps;
- genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc, clauseOps);
+ genWorkshareClauses(converter, semaCtx, stmtCtx, item->clauses, loc,
+ clauseOps);
return genOpWithBody<mlir::omp::WorkshareOp>(
OpWithBodyGenInfo(converter, symTable, semaCtx, loc, eval,
>From c3ec4f1bd4a51139a10b1450e3e194a9270b7362 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:12:34 +0900
Subject: [PATCH 4/6] [flang] Introduce ws loop nest generation for HLFIR
lowering
Emit loop nests in a custom wrapper
Only emit unordered loops as omp loops
Fix uninitialized memory bug in genLoopNest
---
.../flang/Optimizer/Builder/HLFIRTools.h | 12 +++--
flang/lib/Lower/ConvertCall.cpp | 2 +-
flang/lib/Lower/OpenMP/ReductionProcessor.cpp | 4 +-
flang/lib/Optimizer/Builder/HLFIRTools.cpp | 52 ++++++++++++++-----
.../HLFIR/Transforms/BufferizeHLFIR.cpp | 3 +-
.../LowerHLFIROrderedAssignments.cpp | 30 +++++------
.../Transforms/OptimizedBufferization.cpp | 6 +--
7 files changed, 69 insertions(+), 40 deletions(-)
diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6b41025eea0780..f073f494b3fb21 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp(
/// Structure to describe a loop nest.
struct LoopNest {
- fir::DoLoopOp outerLoop;
- fir::DoLoopOp innerLoop;
+ mlir::Operation *outerOp = nullptr;
+ mlir::Block *body = nullptr;
llvm::SmallVector<mlir::Value> oneBasedIndices;
};
@@ -366,11 +366,13 @@ struct LoopNest {
/// \p isUnordered specifies whether the loops in the loop nest
/// are unordered.
LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::ValueRange extents, bool isUnordered = false);
+ mlir::ValueRange extents, bool isUnordered = false,
+ bool emitWorkshareLoop = false);
inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
- mlir::Value shape, bool isUnordered = false) {
+ mlir::Value shape, bool isUnordered = false,
+ bool emitWorkshareLoop = false) {
return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
- isUnordered);
+ isUnordered, emitWorkshareLoop);
}
/// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index fd873f55dd844e..0689d6e033dd9c 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
auto insPt = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+ builder.setInsertionPointToStart(loopNest.body);
callContext.stmtCtx.pushScope();
for (auto &preparedActual : loweredActuals)
if (preparedActual)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c3c1f363033c27..72a90dd0d6f29d 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
// know this won't miss any opportuinties for clever elemental inlining
hlfir::LoopNest nest = hlfir::genLoopNest(
loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
- builder.setInsertionPointToStart(nest.innerLoop.getBody());
+ builder.setInsertionPointToStart(nest.body);
mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
builder, loc, redId, refTy, lhsEle, rhsEle);
builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
- builder.setInsertionPointAfter(nest.outerLoop);
+ builder.setInsertionPointAfter(nest.outerOp);
builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
}
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 8d0ae2f195178c..333331378841ed 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -20,6 +20,7 @@
#include "mlir/IR/IRMapping.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/TypeSwitch.h"
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
#include <optional>
// Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
fir::FirOpBuilder &builder,
- mlir::ValueRange extents, bool isUnordered) {
+ mlir::ValueRange extents, bool isUnordered,
+ bool emitWorkshareLoop) {
+ emitWorkshareLoop = emitWorkshareLoop && isUnordered;
hlfir::LoopNest loopNest;
assert(!extents.empty() && "must have at least one extent");
- auto insPt = builder.saveInsertionPoint();
+ mlir::OpBuilder::InsertionGuard guard(builder);
loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
// Build loop nest from column to row.
auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
mlir::Type indexType = builder.getIndexType();
- unsigned dim = extents.size() - 1;
- for (auto extent : llvm::reverse(extents)) {
- auto ub = builder.createConvert(loc, indexType, extent);
- loopNest.innerLoop =
- builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
- builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
- // Reverse the indices so they are in column-major order.
- loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
- if (!loopNest.outerLoop)
- loopNest.outerLoop = loopNest.innerLoop;
+ if (emitWorkshareLoop) {
+ auto wslw = builder.create<mlir::omp::WorkshareLoopWrapperOp>(loc);
+ loopNest.outerOp = wslw;
+ builder.createBlock(&wslw.getRegion());
+ mlir::omp::LoopNestOperands lnops;
+ lnops.loopInclusive = builder.getUnitAttr();
+ for (auto extent : llvm::reverse(extents)) {
+ lnops.loopLowerBounds.push_back(one);
+ lnops.loopUpperBounds.push_back(extent);
+ lnops.loopSteps.push_back(one);
+ }
+ auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
+ builder.create<mlir::omp::TerminatorOp>(loc);
+ mlir::Block *block = builder.createBlock(&lnOp.getRegion());
+ for (auto extent : llvm::reverse(extents))
+ block->addArgument(extent.getType(), extent.getLoc());
+ loopNest.body = block;
+ builder.create<mlir::omp::YieldOp>(loc);
+ for (unsigned dim = 0; dim < extents.size(); dim++)
+ loopNest.oneBasedIndices[extents.size() - dim - 1] =
+ lnOp.getRegion().front().getArgument(dim);
+ } else {
+ unsigned dim = extents.size() - 1;
+ for (auto extent : llvm::reverse(extents)) {
+ auto ub = builder.createConvert(loc, indexType, extent);
+ auto doLoop =
+ builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
+ loopNest.body = doLoop.getBody();
+ builder.setInsertionPointToStart(loopNest.body);
+ // Reverse the indices so they are in column-major order.
+ loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
+ if (!loopNest.outerOp)
+ loopNest.outerOp = doLoop;
+ }
}
- builder.restoreInsertionPoint(insPt);
return loopNest;
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index a70a6b388c4b1a..b608677c526310 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -31,6 +31,7 @@
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/ADT/TypeSwitch.h"
namespace hlfir {
@@ -793,7 +794,7 @@ struct ElementalOpConversion
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
auto insPt = builder.saveInsertionPoint();
- builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+ builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue(yield.getElementValue());
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 85dd517cb57914..645abf65d10a32 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
// if the LHS is not).
mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
- builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
+ builder.setInsertionPointToStart(elementalLoopNest->body);
lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
elementalLoopNest->oneBasedIndices);
rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
for (auto &cleanupConversion : argConversionCleanups)
cleanupConversion();
if (elementalLoopNest)
- builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
+ builder.setInsertionPointAfter(elementalLoopNest->outerOp);
} else {
// TODO: preserve allocatable assignment aspects for forall once
// they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
generateCleanupIfAny(loweredLhs.elementalCleanup);
if (loweredLhs.vectorSubscriptLoopNest)
builder.setInsertionPointAfter(
- loweredLhs.vectorSubscriptLoopNest->outerLoop);
+ loweredLhs.vectorSubscriptLoopNest->outerOp);
generateCleanupIfAny(oldRhsYield);
generateCleanupIfAny(loweredLhs.nonElementalCleanup);
}
@@ -518,8 +518,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
hlfir::Entity savedMask{maybeSaved->first};
mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
- constructStack.push_back(whereLoopNest->outerLoop.getOperation());
- builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+ constructStack.push_back(whereLoopNest->outerOp);
+ builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
whereLoopNest->oneBasedIndices);
generateMaskIfOp(cdt);
@@ -527,7 +527,7 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
// If this is the same run as the one that saved the value, the clean-up
// was left-over to be done now.
auto insertionPoint = builder.saveInsertionPoint();
- builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+ builder.setInsertionPointAfter(whereLoopNest->outerOp);
generateCleanupIfAny(maybeSaved->second);
builder.restoreInsertionPoint(insertionPoint);
}
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
mask.generateNoneElementalPart(builder, mapper);
mlir::Value shape = mask.generateShape(builder, mapper);
whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
- constructStack.push_back(whereLoopNest->outerLoop.getOperation());
- builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+ constructStack.push_back(whereLoopNest->outerOp);
+ builder.setInsertionPointToStart(whereLoopNest->body);
mlir::Value cdt = generateMaskedEntity(mask);
generateMaskIfOp(cdt);
return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
loc, builder, loweredLhs.vectorSubscriptShape.value());
builder.setInsertionPointToStart(
- loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+ loweredLhs.vectorSubscriptLoopNest->body);
}
loweredLhs.lhs = temp->second.fetch(loc, builder);
return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
!elementalAddrLhs.isOrdered());
builder.setInsertionPointToStart(
- loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+ loweredLhs.vectorSubscriptLoopNest->body);
mapper.map(elementalAddrLhs.getIndices(),
loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
if (!maskedExpr.noneElementalPartWasGenerated) {
// Generate none elemental part before the where loops (but inside the
// current forall loops if any).
- builder.setInsertionPoint(whereLoopNest->outerLoop);
+ builder.setInsertionPoint(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalPart(builder, mapper);
}
// Generate the none elemental part cleanup after the where loops.
- builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+ builder.setInsertionPointAfter(whereLoopNest->outerOp);
maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
// Generate the value of the current element for the masked expression
// at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
fir::factory::TemporaryStorage *temp = nullptr;
if (loweredLhs.vectorSubscriptLoopNest)
- constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+ constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
// Vector subscripted entity for which the shape must also be saved on top
// of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
// subscripted LHS.
auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
auto insertionPoint = builder.saveInsertionPoint();
- builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+ builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
vectorTmp.pushShape(loc, builder, shape);
builder.restoreInsertionPoint(insertionPoint);
} else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
if (loweredLhs.vectorSubscriptLoopNest) {
constructStack.pop_back();
builder.setInsertionPointAfter(
- loweredLhs.vectorSubscriptLoopNest->outerLoop);
+ loweredLhs.vectorSubscriptLoopNest->outerOp);
}
}
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 7553e05b470634..3a0a98dc594463 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
// hlfir.elemental region inside the inner loop
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
- builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+ builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
- builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+ builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
- builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+ builder.setInsertionPointToStart(loopNest.body);
auto rhsArrayElement =
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);
>From acee889bbdb84cc70760a9c298cd7fa534473f54 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 22:06:55 +0900
Subject: [PATCH 5/6] [flang] Lower omp.workshare to other omp constructs
Change to workshare loop wrapper op
Move single op declaration
Schedule pass properly
Correctly handle nested loop nests to be parallelized by workshare
Leave comments for shouldUseWorkshareLowering
Use copyprivate to scatter val from omp.single
TODO still need to implement copy function
TODO transitive check for usage outside of omp.single not implemented yet
Transitively check for users outside of single op
TODO need to implement copy func
TODO need to hoist allocas outside of single regions
Add tests
Hoist allocas
More tests
Emit body for copy func
Test the tmp storing logic
Clean up trivially dead ops
Only handle single-block regions for now
Fix tests for custom assembly for loop wrapper
Only run the lower workshare pass if openmp is enabled
Implement some missing functionality
Fix tests
Fix test
Iterate backwards to find all trivially dead ops
Add explanation comment for createCopyFun
Update test
---
flang/include/flang/Optimizer/OpenMP/Passes.h | 5 +
.../include/flang/Optimizer/OpenMP/Passes.td | 5 +
flang/include/flang/Tools/CLOptions.inc | 6 +-
flang/include/flang/Tools/CrossToolHelpers.h | 1 +
flang/lib/Frontend/FrontendActions.cpp | 10 +-
flang/lib/Optimizer/OpenMP/CMakeLists.txt | 1 +
flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp | 446 ++++++++++++++++++
flang/test/Fir/basic-program.fir | 1 +
.../Transforms/OpenMP/lower-workshare.mlir | 189 ++++++++
.../Transforms/OpenMP/lower-workshare2.mlir | 23 +
.../Transforms/OpenMP/lower-workshare3.mlir | 74 +++
.../Transforms/OpenMP/lower-workshare4.mlir | 59 +++
.../Transforms/OpenMP/lower-workshare5.mlir | 42 ++
.../Transforms/OpenMP/lower-workshare6.mlir | 51 ++
flang/tools/bbc/bbc.cpp | 5 +-
flang/tools/tco/tco.cpp | 1 +
16 files changed, 915 insertions(+), 4 deletions(-)
create mode 100644 flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare2.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare3.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare4.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare5.mlir
create mode 100644 flang/test/Transforms/OpenMP/lower-workshare6.mlir
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.h b/flang/include/flang/Optimizer/OpenMP/Passes.h
index 403d79667bf448..feb395f1a12dbd 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.h
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.h
@@ -25,6 +25,11 @@ namespace flangomp {
#define GEN_PASS_REGISTRATION
#include "flang/Optimizer/OpenMP/Passes.h.inc"
+/// Implements the logic specified in the 2.8.3 workshare Construct section of
+/// the OpenMP standard which specifies what statements or constructs shall be
+/// divided into units of work.
+bool shouldUseWorkshareLowering(mlir::Operation *op);
+
} // namespace flangomp
#endif // FORTRAN_OPTIMIZER_OPENMP_PASSES_H
diff --git a/flang/include/flang/Optimizer/OpenMP/Passes.td b/flang/include/flang/Optimizer/OpenMP/Passes.td
index 395178e26a5762..041240cad12eb3 100644
--- a/flang/include/flang/Optimizer/OpenMP/Passes.td
+++ b/flang/include/flang/Optimizer/OpenMP/Passes.td
@@ -37,4 +37,9 @@ def FunctionFiltering : Pass<"omp-function-filtering"> {
];
}
+// Needs to be scheduled on Module as we create functions in it
+def LowerWorkshare : Pass<"lower-workshare", "::mlir::ModuleOp"> {
+ let summary = "Lower workshare construct";
+}
+
#endif //FORTRAN_OPTIMIZER_OPENMP_PASSES
diff --git a/flang/include/flang/Tools/CLOptions.inc b/flang/include/flang/Tools/CLOptions.inc
index 1881e23b00045a..bb00e079008a0b 100644
--- a/flang/include/flang/Tools/CLOptions.inc
+++ b/flang/include/flang/Tools/CLOptions.inc
@@ -337,7 +337,7 @@ inline void createDefaultFIROptimizerPassPipeline(
/// \param optLevel - optimization level used for creating FIR optimization
/// passes pipeline
inline void createHLFIRToFIRPassPipeline(
- mlir::PassManager &pm, llvm::OptimizationLevel optLevel = defaultOptLevel) {
+ mlir::PassManager &pm, bool enableOpenMP, llvm::OptimizationLevel optLevel = defaultOptLevel) {
if (optLevel.isOptimizingForSpeed()) {
addCanonicalizerPassWithoutRegionSimplification(pm);
addNestedPassToAllTopLevelOperations(
@@ -354,6 +354,8 @@ inline void createHLFIRToFIRPassPipeline(
pm.addPass(hlfir::createLowerHLFIRIntrinsics());
pm.addPass(hlfir::createBufferizeHLFIR());
pm.addPass(hlfir::createConvertHLFIRtoFIR());
+ if (enableOpenMP)
+ pm.addPass(flangomp::createLowerWorkshare());
}
/// Create a pass pipeline for handling certain OpenMP transformations needed
@@ -425,7 +427,7 @@ inline void createDefaultFIRCodeGenPassPipeline(mlir::PassManager &pm,
/// passes pipeline
inline void createMLIRToLLVMPassPipeline(mlir::PassManager &pm,
MLIRToLLVMPassPipelineConfig &config, llvm::StringRef inputFilename = {}) {
- fir::createHLFIRToFIRPassPipeline(pm, config.OptLevel);
+ fir::createHLFIRToFIRPassPipeline(pm, config.EnableOpenMP, config.OptLevel);
// Add default optimizer pass pipeline.
fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/include/flang/Tools/CrossToolHelpers.h b/flang/include/flang/Tools/CrossToolHelpers.h
index 75fd783af237d0..0911b9bca67332 100644
--- a/flang/include/flang/Tools/CrossToolHelpers.h
+++ b/flang/include/flang/Tools/CrossToolHelpers.h
@@ -123,6 +123,7 @@ struct MLIRToLLVMPassPipelineConfig : public FlangEPCallBacks {
false; ///< Set no-signed-zeros-fp-math attribute for functions.
bool UnsafeFPMath = false; ///< Set unsafe-fp-math attribute for functions.
bool NSWOnLoopVarInc = false; ///< Add nsw flag to loop variable increments.
+ bool EnableOpenMP = false; ///< Enable OpenMP lowering.
};
struct OffloadModuleOpts {
diff --git a/flang/lib/Frontend/FrontendActions.cpp b/flang/lib/Frontend/FrontendActions.cpp
index 5c86bd947ce73f..db5c5649337528 100644
--- a/flang/lib/Frontend/FrontendActions.cpp
+++ b/flang/lib/Frontend/FrontendActions.cpp
@@ -711,7 +711,11 @@ void CodeGenAction::lowerHLFIRToFIR() {
pm.enableVerifier(/*verifyPasses=*/true);
// Create the pass pipeline
- fir::createHLFIRToFIRPassPipeline(pm, level);
+ fir::createHLFIRToFIRPassPipeline(
+ pm,
+ ci.getInvocation().getFrontendOpts().features.IsEnabled(
+ Fortran::common::LanguageFeature::OpenMP),
+ level);
(void)mlir::applyPassManagerCLOptions(pm);
if (!mlir::succeeded(pm.run(*mlirModule))) {
@@ -824,6 +828,10 @@ void CodeGenAction::generateLLVMIR() {
config.VScaleMax = vsr->second;
}
+ if (ci.getInvocation().getFrontendOpts().features.IsEnabled(
+ Fortran::common::LanguageFeature::OpenMP))
+ config.EnableOpenMP = true;
+
if (ci.getInvocation().getLoweringOpts().getNSWOnLoopVarInc())
config.NSWOnLoopVarInc = true;
diff --git a/flang/lib/Optimizer/OpenMP/CMakeLists.txt b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
index 92051634f0378b..39e92d388288d4 100644
--- a/flang/lib/Optimizer/OpenMP/CMakeLists.txt
+++ b/flang/lib/Optimizer/OpenMP/CMakeLists.txt
@@ -4,6 +4,7 @@ add_flang_library(FlangOpenMPTransforms
FunctionFiltering.cpp
MapInfoFinalization.cpp
MarkDeclareTarget.cpp
+ LowerWorkshare.cpp
DEPENDS
FIRDialect
diff --git a/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
new file mode 100644
index 00000000000000..6e5538b54ba5e0
--- /dev/null
+++ b/flang/lib/Optimizer/OpenMP/LowerWorkshare.cpp
@@ -0,0 +1,446 @@
+//===- LowerWorkshare.cpp - Lower omp.workshare constructs ----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the lowering of omp.workshare to other omp constructs.
+//
+// This pass is tasked with parallelizing the loops nested in
+// workshare.loop_wrapper while both the Fortran to mlir lowering and the hlfir
+// to fir lowering pipelines are responsible for emitting the
+// workshare.loop_wrapper ops where appropriate according to the
+// `shouldUseWorkshareLowering` function.
+//
+//===----------------------------------------------------------------------===//
+
+#include <flang/Optimizer/Builder/FIRBuilder.h>
+#include <flang/Optimizer/Dialect/FIROps.h>
+#include <flang/Optimizer/Dialect/FIRType.h>
+#include <flang/Optimizer/HLFIR/HLFIROps.h>
+#include <flang/Optimizer/OpenMP/Passes.h>
+#include <llvm/ADT/BreadthFirstIterator.h>
+#include <llvm/ADT/STLExtras.h>
+#include <llvm/ADT/SmallVectorExtras.h>
+#include <llvm/ADT/iterator_range.h>
+#include <llvm/Support/ErrorHandling.h>
+#include <mlir/Dialect/Arith/IR/Arith.h>
+#include <mlir/Dialect/LLVMIR/LLVMTypes.h>
+#include <mlir/Dialect/OpenMP/OpenMPClauseOperands.h>
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
+#include <mlir/Dialect/SCF/IR/SCF.h>
+#include <mlir/IR/BuiltinOps.h>
+#include <mlir/IR/IRMapping.h>
+#include <mlir/IR/OpDefinition.h>
+#include <mlir/IR/PatternMatch.h>
+#include <mlir/IR/Visitors.h>
+#include <mlir/Interfaces/SideEffectInterfaces.h>
+#include <mlir/Support/LLVM.h>
+#include <mlir/Transforms/GreedyPatternRewriteDriver.h>
+
+#include <variant>
+
+namespace flangomp {
+#define GEN_PASS_DEF_LOWERWORKSHARE
+#include "flang/Optimizer/OpenMP/Passes.h.inc"
+} // namespace flangomp
+
+#define DEBUG_TYPE "lower-workshare"
+
+using namespace mlir;
+
+namespace flangomp {
+
+// Checks for nesting pattern below as we need to avoid sharing the work of
+// statements which are nested in some constructs such as omp.critical or
+// another omp.parallel.
+//
+// omp.workshare { // `wsOp`
+// ...
+// omp.T { // `parent`
+// ...
+// `op`
+//
+template <typename T>
+static bool isNestedIn(omp::WorkshareOp wsOp, Operation *op) {
+ T parent = op->getParentOfType<T>();
+ if (!parent)
+ return false;
+ return wsOp->isProperAncestor(parent);
+}
+
+bool shouldUseWorkshareLowering(Operation *op) {
+ auto parentWorkshare = op->getParentOfType<omp::WorkshareOp>();
+
+ if (!parentWorkshare)
+ return false;
+
+ if (isNestedIn<omp::CriticalOp>(parentWorkshare, op))
+ return false;
+
+ // 2.8.3 workshare Construct
+ // For a parallel construct, the construct is a unit of work with respect to
+ // the workshare construct. The statements contained in the parallel construct
+ // are executed by a new thread team.
+ if (isNestedIn<omp::ParallelOp>(parentWorkshare, op))
+ return false;
+
+ // 2.8.2 single Construct
+ // Binding The binding thread set for a single region is the current team. A
+ // single region binds to the innermost enclosing parallel region.
+ // Description Only one of the encountering threads will execute the
+ // structured block associated with the single construct.
+ if (isNestedIn<omp::SingleOp>(parentWorkshare, op))
+ return false;
+
+ return true;
+}
+
+} // namespace flangomp
+
+namespace {
+
+struct SingleRegion {
+ Block::iterator begin, end;
+};
+
+/// Returns true if `op` is, or contains, an omp.workshare.loop_wrapper that
+/// binds to the omp.workshare region currently being lowered.
+static bool mustParallelizeOp(Operation *op) {
+  // Pre-order is required so that WalkResult::skip() actually prunes the
+  // subtree of a nested omp.workshare before its contents are visited, e.g.
+  //
+  // omp.parallel {
+  //   omp.workshare { // currently handling this
+  //     omp.parallel {
+  //       omp.workshare { // nested workshare
+  //         omp.workshare.loop_wrapper {} // must not be picked up
+  //
+  // The checks must inspect `nested` (the op being visited), not `op` (the
+  // root of the walk), or wrappers nested in e.g. fir.if would be missed.
+  return op
+      ->walk<WalkOrder::PreOrder>([&](Operation *nested) {
+        if (isa<omp::WorkshareOp>(nested))
+          return WalkResult::skip();
+        if (isa<omp::WorkshareLoopWrapperOp>(nested))
+          return WalkResult::interrupt();
+        return WalkResult::advance();
+      })
+      .wasInterrupted();
+}
+
+static bool isSafeToParallelize(Operation *op) {
+ return isa<hlfir::DeclareOp>(op) || isa<fir::DeclareOp>(op) ||
+ isMemoryEffectFree(op);
+}
+
+/// Simple shallow copies suffice for our purposes in this pass, so we implement
+/// this simpler alternative to the full fledged `createCopyFunc` in the
+/// frontend
+static mlir::func::FuncOp createCopyFunc(mlir::Location loc, mlir::Type varType,
+ fir::FirOpBuilder builder) {
+ mlir::ModuleOp module = builder.getModule();
+ auto rt = cast<fir::ReferenceType>(varType);
+ mlir::Type eleTy = rt.getEleTy();
+ std::string copyFuncName =
+ fir::getTypeAsString(eleTy, builder.getKindMap(), "_workshare_copy");
+
+ if (auto decl = module.lookupSymbol<mlir::func::FuncOp>(copyFuncName))
+ return decl;
+ // create function
+ mlir::OpBuilder::InsertionGuard guard(builder);
+ mlir::OpBuilder modBuilder(module.getBodyRegion());
+ llvm::SmallVector<mlir::Type> argsTy = {varType, varType};
+ auto funcType = mlir::FunctionType::get(builder.getContext(), argsTy, {});
+ mlir::func::FuncOp funcOp =
+ modBuilder.create<mlir::func::FuncOp>(loc, copyFuncName, funcType);
+ funcOp.setVisibility(mlir::SymbolTable::Visibility::Private);
+ builder.createBlock(&funcOp.getRegion(), funcOp.getRegion().end(), argsTy,
+ {loc, loc});
+ builder.setInsertionPointToStart(&funcOp.getRegion().back());
+
+ Value loaded = builder.create<fir::LoadOp>(loc, funcOp.getArgument(0));
+ builder.create<fir::StoreOp>(loc, loaded, funcOp.getArgument(1));
+
+ builder.create<mlir::func::ReturnOp>(loc);
+ return funcOp;
+}
+
+static bool isUserOutsideSR(Operation *user, Operation *parentOp,
+ SingleRegion sr) {
+ while (user->getParentOp() != parentOp)
+ user = user->getParentOp();
+ return sr.begin->getBlock() != user->getBlock() ||
+ !(user->isBeforeInBlock(&*sr.end) && sr.begin->isBeforeInBlock(user));
+}
+
+static bool isTransitivelyUsedOutside(Value v, SingleRegion sr) {
+ Block *srBlock = sr.begin->getBlock();
+ Operation *parentOp = srBlock->getParentOp();
+
+ for (auto &use : v.getUses()) {
+ Operation *user = use.getOwner();
+ if (isUserOutsideSR(user, parentOp, sr))
+ return true;
+
+ // Results of nested users cannot be used outside of the SR
+ if (user->getBlock() != srBlock)
+ continue;
+
+ // A non-safe to parallelize operation will be handled separately
+ if (!isSafeToParallelize(user))
+ continue;
+
+ for (auto res : user->getResults())
+ if (isTransitivelyUsedOutside(res, sr))
+ return true;
+ }
+ return false;
+}
+
+/// We clone pure operations into both the parallel and single blocks. This
+/// function cleans them up if they end up with no uses.
+static void cleanupBlock(Block *block) {
+ for (Operation &op : llvm::make_early_inc_range(
+ llvm::make_range(block->rbegin(), block->rend())))
+ if (isOpTriviallyDead(&op))
+ op.erase();
+}
+
+static void parallelizeRegion(Region &sourceRegion, Region &targetRegion,
+ IRMapping &rootMapping, Location loc,
+ mlir::DominanceInfo &di) {
+ OpBuilder rootBuilder(sourceRegion.getContext());
+ ModuleOp m = sourceRegion.getParentOfType<ModuleOp>();
+ OpBuilder copyFuncBuilder(m.getBodyRegion());
+ fir::FirOpBuilder firCopyFuncBuilder(copyFuncBuilder, m);
+
+ auto mapReloadedValue =
+ [&](Value v, OpBuilder allocaBuilder, OpBuilder singleBuilder,
+ OpBuilder parallelBuilder, IRMapping singleMapping) -> Value {
+ if (auto reloaded = rootMapping.lookupOrNull(v))
+ return nullptr;
+ Type ty = v.getType();
+ Value alloc = allocaBuilder.create<fir::AllocaOp>(loc, ty);
+ singleBuilder.create<fir::StoreOp>(loc, singleMapping.lookup(v), alloc);
+ Value reloaded = parallelBuilder.create<fir::LoadOp>(loc, ty, alloc);
+ rootMapping.map(v, reloaded);
+ return alloc;
+ };
+
+ auto moveToSingle = [&](SingleRegion sr, OpBuilder allocaBuilder,
+ OpBuilder singleBuilder,
+ OpBuilder parallelBuilder) -> SmallVector<Value> {
+ IRMapping singleMapping = rootMapping;
+ SmallVector<Value> copyPrivate;
+
+ for (Operation &op : llvm::make_range(sr.begin, sr.end)) {
+ if (isSafeToParallelize(&op)) {
+ singleBuilder.clone(op, singleMapping);
+ parallelBuilder.clone(op, rootMapping);
+ } else if (auto alloca = dyn_cast<fir::AllocaOp>(&op)) {
+ auto hoisted =
+ cast<fir::AllocaOp>(allocaBuilder.clone(*alloca, singleMapping));
+ rootMapping.map(&*alloca, &*hoisted);
+ rootMapping.map(alloca.getResult(), hoisted.getResult());
+ copyPrivate.push_back(hoisted);
+ } else {
+ singleBuilder.clone(op, singleMapping);
+ // Prepare reloaded values for results of operations that cannot be
+ // safely parallelized and which are used after the region `sr`
+ for (auto res : op.getResults()) {
+ if (isTransitivelyUsedOutside(res, sr)) {
+ auto alloc = mapReloadedValue(res, allocaBuilder, singleBuilder,
+ parallelBuilder, singleMapping);
+ if (alloc)
+ copyPrivate.push_back(alloc);
+ }
+ }
+ }
+ }
+ singleBuilder.create<omp::TerminatorOp>(loc);
+ return copyPrivate;
+ };
+
+ for (Block &block : sourceRegion) {
+ Block *targetBlock = rootBuilder.createBlock(
+ &targetRegion, {}, block.getArgumentTypes(),
+ llvm::map_to_vector(block.getArguments(),
+ [](BlockArgument arg) { return arg.getLoc(); }));
+ rootMapping.map(&block, targetBlock);
+ rootMapping.map(block.getArguments(), targetBlock->getArguments());
+ }
+
+ auto handleOneBlock = [&](Block &block) {
+ Block &targetBlock = *rootMapping.lookup(&block);
+ rootBuilder.setInsertionPointToStart(&targetBlock);
+ Operation *terminator = block.getTerminator();
+ SmallVector<std::variant<SingleRegion, Operation *>> regions;
+
+ auto it = block.begin();
+ auto getOneRegion = [&]() {
+ if (&*it == terminator)
+ return false;
+ if (mustParallelizeOp(&*it)) {
+ regions.push_back(&*it);
+ it++;
+ return true;
+ }
+ SingleRegion sr;
+ sr.begin = it;
+ while (&*it != terminator && !mustParallelizeOp(&*it))
+ it++;
+ sr.end = it;
+ assert(sr.begin != sr.end);
+ regions.push_back(sr);
+ return true;
+ };
+ while (getOneRegion())
+ ;
+
+ for (auto [i, opOrSingle] : llvm::enumerate(regions)) {
+ bool isLast = i + 1 == regions.size();
+ if (std::holds_alternative<SingleRegion>(opOrSingle)) {
+ OpBuilder singleBuilder(sourceRegion.getContext());
+ Block *singleBlock = new Block();
+ singleBuilder.setInsertionPointToStart(singleBlock);
+
+ OpBuilder allocaBuilder(sourceRegion.getContext());
+ Block *allocaBlock = new Block();
+ allocaBuilder.setInsertionPointToStart(allocaBlock);
+
+ OpBuilder parallelBuilder(sourceRegion.getContext());
+ Block *parallelBlock = new Block();
+ parallelBuilder.setInsertionPointToStart(parallelBlock);
+
+ omp::SingleOperands singleOperands;
+ if (isLast)
+ singleOperands.nowait = rootBuilder.getUnitAttr();
+ singleOperands.copyprivateVars =
+ moveToSingle(std::get<SingleRegion>(opOrSingle), allocaBuilder,
+ singleBuilder, parallelBuilder);
+ cleanupBlock(singleBlock);
+ for (auto var : singleOperands.copyprivateVars) {
+ mlir::func::FuncOp funcOp =
+ createCopyFunc(loc, var.getType(), firCopyFuncBuilder);
+ singleOperands.copyprivateSyms.push_back(SymbolRefAttr::get(funcOp));
+ }
+ omp::SingleOp singleOp =
+ rootBuilder.create<omp::SingleOp>(loc, singleOperands);
+ singleOp.getRegion().push_back(singleBlock);
+ rootBuilder.getInsertionBlock()->getOperations().splice(
+ rootBuilder.getInsertionPoint(), parallelBlock->getOperations());
+ targetRegion.front().getOperations().splice(
+ singleOp->getIterator(), allocaBlock->getOperations());
+ delete allocaBlock;
+ delete parallelBlock;
+ } else {
+ auto op = std::get<Operation *>(opOrSingle);
+ if (auto wslw = dyn_cast<omp::WorkshareLoopWrapperOp>(op)) {
+ omp::WsloopOperands wsloopOperands;
+ if (isLast)
+ wsloopOperands.nowait = rootBuilder.getUnitAttr();
+ auto wsloop =
+ rootBuilder.create<mlir::omp::WsloopOp>(loc, wsloopOperands);
+ auto clonedWslw = cast<omp::WorkshareLoopWrapperOp>(
+ rootBuilder.clone(*wslw, rootMapping));
+ wsloop.getRegion().takeBody(clonedWslw.getRegion());
+ clonedWslw->erase();
+ } else {
+ assert(mustParallelizeOp(op));
+ Operation *cloned = rootBuilder.cloneWithoutRegions(*op, rootMapping);
+ for (auto [region, clonedRegion] :
+ llvm::zip(op->getRegions(), cloned->getRegions()))
+ parallelizeRegion(region, clonedRegion, rootMapping, loc, di);
+ }
+ }
+ }
+
+ rootBuilder.clone(*block.getTerminator(), rootMapping);
+ };
+
+ if (sourceRegion.hasOneBlock()) {
+ handleOneBlock(sourceRegion.front());
+ } else {
+ auto &domTree = di.getDomTree(&sourceRegion);
+ for (auto node : llvm::breadth_first(domTree.getRootNode())) {
+ handleOneBlock(*node->getBlock());
+ }
+ }
+
+ for (Block &targetBlock : targetRegion)
+ cleanupBlock(&targetBlock);
+}
+
+/// Lowers workshare to a sequence of single-thread regions and parallel loops
+///
+/// For example:
+///
+/// omp.workshare {
+/// %a = fir.allocmem
+/// omp.workshare.loop_wrapper {}
+/// fir.call Assign %b %a
+/// fir.freemem %a
+/// }
+///
+/// becomes
+///
+/// %tmp = fir.alloca
+/// omp.single copyprivate(%tmp) {
+/// %a = fir.allocmem
+/// fir.store %a %tmp
+/// }
+/// %a_reloaded = fir.load %tmp
+/// omp.wsloop {}
+/// omp.single {
+/// fir.call Assign %b %a_reloaded
+/// fir.freemem %a_reloaded
+/// }
+///
+/// Note that we allocate temporary memory for values in omp.single's which need
+/// to be accessed by all threads and broadcast them using single's copyprivate
+LogicalResult lowerWorkshare(mlir::omp::WorkshareOp wsOp, DominanceInfo &di) {
+  // Only single-block regions are supported for now. Bail out before touching
+  // the IR so that a failure does not leave a half-built placeholder behind.
+  if (!wsOp.getRegion().hasOneBlock())
+    return failure();
+
+  Location loc = wsOp->getLoc();
+  IRMapping rootMapping;
+  OpBuilder rootBuilder(wsOp);
+
+  // This operation is just a placeholder which will be erased later. We need it
+  // because our `parallelizeRegion` function works on regions and not blocks.
+  omp::WorkshareOp newOp =
+      rootBuilder.create<omp::WorkshareOp>(loc, omp::WorkshareOperands());
+  if (!wsOp.getNowait())
+    rootBuilder.create<omp::BarrierOp>(loc);
+  parallelizeRegion(wsOp.getRegion(), newOp.getRegion(), rootMapping, loc, di);
+
+  // Inline the contents of the placeholder workshare op into its parent block.
+  Block *theBlock = &newOp.getRegion().front();
+  Operation *term = theBlock->getTerminator();
+  Block *parentBlock = wsOp->getBlock();
+  parentBlock->getOperations().splice(newOp->getIterator(),
+                                      theBlock->getOperations());
+  assert(term->getNumOperands() == 0);
+  term->erase();
+  newOp->erase();
+  wsOp->erase();
+  return success();
+}
+
+class LowerWorksharePass
+ : public flangomp::impl::LowerWorkshareBase<LowerWorksharePass> {
+public:
+ void runOnOperation() override {
+ mlir::DominanceInfo &di = getAnalysis<mlir::DominanceInfo>();
+ getOperation()->walk([&](mlir::omp::WorkshareOp wsOp) {
+ if (failed(lowerWorkshare(wsOp, di)))
+ signalPassFailure();
+ });
+ }
+};
+} // namespace
diff --git a/flang/test/Fir/basic-program.fir b/flang/test/Fir/basic-program.fir
index bca454c13ff9cc..4b18acb7c2b430 100644
--- a/flang/test/Fir/basic-program.fir
+++ b/flang/test/Fir/basic-program.fir
@@ -47,6 +47,7 @@ func.func @_QQmain() {
// PASSES-NEXT: LowerHLFIRIntrinsics
// PASSES-NEXT: BufferizeHLFIR
// PASSES-NEXT: ConvertHLFIRtoFIR
+// PASSES-NEXT: LowerWorkshare
// PASSES-NEXT: CSE
// PASSES-NEXT: (S) 0 num-cse'd - Number of operations CSE'd
// PASSES-NEXT: (S) 0 num-dce'd - Number of operations DCE'd
diff --git a/flang/test/Transforms/OpenMP/lower-workshare.mlir b/flang/test/Transforms/OpenMP/lower-workshare.mlir
new file mode 100644
index 00000000000000..a609ee5d3d6c2a
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare.mlir
@@ -0,0 +1,189 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// checks:
+// nowait on final omp.single
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+ omp.parallel {
+ omp.workshare {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+ %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+ %true = arith.constant true
+ %c1 = arith.constant 1 : index
+ omp.workshare.loop_wrapper {
+ omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+ %7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ %8 = fir.load %7 : !fir.ref<i32>
+ %9 = arith.subi %8, %c1_i32 : i32
+ %10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+ omp.yield
+ }
+ omp.terminator
+ }
+ %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+ fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// checks:
+// fir.alloca hoisted out and copyprivate'd
+func.func @wsfunc(%arg0: !fir.ref<!fir.array<42xi32>>) {
+ omp.workshare {
+ %c1_i32 = arith.constant 1 : i32
+ %alloc = fir.alloca i32
+ fir.store %c1_i32 to %alloc : !fir.ref<i32>
+ %c42 = arith.constant 42 : index
+ %0 = fir.shape %c42 : (index) -> !fir.shape<1>
+ %1:2 = hlfir.declare %arg0(%0) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %2 = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+ %3:2 = hlfir.declare %2(%0) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+ %true = arith.constant true
+ %c1 = arith.constant 1 : index
+ omp.workshare.loop_wrapper {
+ omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+ %7 = hlfir.designate %1#0 (%arg1) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ %8 = fir.load %7 : !fir.ref<i32>
+ %ld = fir.load %alloc : !fir.ref<i32>
+ %n8 = arith.subi %8, %ld : i32
+ %9 = arith.subi %n8, %c1_i32 : i32
+ %10 = hlfir.designate %3#0 (%arg1) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ hlfir.assign %9 to %10 temporary_lhs : i32, !fir.ref<i32>
+ omp.yield
+ }
+ omp.terminator
+ }
+ %4 = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ %5 = fir.insert_value %4, %true, [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ %6 = fir.insert_value %5, %3#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+ "test.test1"(%alloc) : (!fir.ref<i32>) -> ()
+ hlfir.assign %3#0 to %1#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+ fir.freemem %3#0 : !fir.heap<!fir.array<42xi32>>
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func @wsfunc(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK: omp.parallel {
+// CHECK: %[[VAL_1:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK: %[[VAL_2:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_2]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK: fir.store %[[VAL_5]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_7:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_9:.*]] = fir.shape %[[VAL_7]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_10:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_9]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_11]](%[[VAL_9]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: %[[VAL_13:.*]] = arith.constant 1 : index
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_14:.*]]) : index = (%[[VAL_13]]) to (%[[VAL_7]]) inclusive step (%[[VAL_13]]) {
+// CHECK: %[[VAL_15:.*]] = hlfir.designate %[[VAL_10]]#0 (%[[VAL_14]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_16:.*]] = fir.load %[[VAL_15]] : !fir.ref<i32>
+// CHECK: %[[VAL_17:.*]] = arith.subi %[[VAL_16]], %[[VAL_8]] : i32
+// CHECK: %[[VAL_18:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_14]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_17]] to %[[VAL_18]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.single nowait {
+// CHECK: hlfir.assign %[[VAL_12]]#0 to %[[VAL_10]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK: fir.freemem %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.barrier
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_workshare_copy_heap_42xi32(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func private @_workshare_copy_i32(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<i32>,
+// CHECK-SAME: %[[VAL_1:.*]]: !fir.ref<i32>) {
+// CHECK: %[[VAL_2:.*]] = fir.load %[[VAL_0]] : !fir.ref<i32>
+// CHECK: fir.store %[[VAL_2]] to %[[VAL_1]] : !fir.ref<i32>
+// CHECK: return
+// CHECK: }
+
+// CHECK-LABEL: func.func @wsfunc(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK: %[[VAL_1:.*]] = fir.alloca i32
+// CHECK: %[[VAL_2:.*]] = fir.alloca !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.single copyprivate(%[[VAL_1]] -> @_workshare_copy_i32 : !fir.ref<i32>, %[[VAL_2]] -> @_workshare_copy_heap_42xi32 : !fir.ref<!fir.heap<!fir.array<42xi32>>>) {
+// CHECK: %[[VAL_3:.*]] = arith.constant 1 : i32
+// CHECK: fir.store %[[VAL_3]] to %[[VAL_1]] : !fir.ref<i32>
+// CHECK: %[[VAL_4:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_5:.*]] = fir.shape %[[VAL_4]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_5]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_7:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK: fir.store %[[VAL_7]] to %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: %[[VAL_8:.*]]:2 = hlfir.declare %[[VAL_7]](%[[VAL_5]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_9:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_10:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_11:.*]] = fir.shape %[[VAL_10]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_12:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_11]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_13:.*]] = fir.load %[[VAL_2]] : !fir.ref<!fir.heap<!fir.array<42xi32>>>
+// CHECK: %[[VAL_14:.*]]:2 = hlfir.declare %[[VAL_13]](%[[VAL_11]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: %[[VAL_15:.*]] = arith.constant 1 : index
+// CHECK: omp.wsloop {
+// CHECK: omp.loop_nest (%[[VAL_16:.*]]) : index = (%[[VAL_15]]) to (%[[VAL_10]]) inclusive step (%[[VAL_15]]) {
+// CHECK: %[[VAL_17:.*]] = hlfir.designate %[[VAL_12]]#0 (%[[VAL_16]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_18:.*]] = fir.load %[[VAL_17]] : !fir.ref<i32>
+// CHECK: %[[VAL_19:.*]] = fir.load %[[VAL_1]] : !fir.ref<i32>
+// CHECK: %[[VAL_20:.*]] = arith.subi %[[VAL_18]], %[[VAL_19]] : i32
+// CHECK: %[[VAL_21:.*]] = arith.subi %[[VAL_20]], %[[VAL_9]] : i32
+// CHECK: %[[VAL_22:.*]] = hlfir.designate %[[VAL_14]]#0 (%[[VAL_16]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_21]] to %[[VAL_22]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.single nowait {
+// CHECK: "test.test1"(%[[VAL_1]]) : (!fir.ref<i32>) -> ()
+// CHECK: hlfir.assign %[[VAL_14]]#0 to %[[VAL_12]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK: fir.freemem %[[VAL_14]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.barrier
+// CHECK: return
+// CHECK: }
+
diff --git a/flang/test/Transforms/OpenMP/lower-workshare2.mlir b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
new file mode 100644
index 00000000000000..940662e0bdccc2
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare2.mlir
@@ -0,0 +1,23 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// Check that we correctly handle nowait
+
+// CHECK-LABEL: func.func @nonowait
+func.func @nonowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+ // CHECK: omp.barrier
+ omp.workshare {
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @nowait
+func.func @nowait(%arg0: !fir.ref<!fir.array<42xi32>>) {
+ // CHECK-NOT: omp.barrier
+ omp.workshare nowait {
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare3.mlir b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
new file mode 100644
index 00000000000000..5a3d583527fddb
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare3.mlir
@@ -0,0 +1,74 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+
+// Check if we store the correct values
+
+func.func @wsfunc() {
+ omp.parallel {
+ // CHECK: fir.alloca
+ // CHECK: fir.alloca
+ // CHECK: fir.alloca
+ // CHECK: fir.alloca
+ // CHECK: fir.alloca
+ // CHECK-NOT: fir.alloca
+ omp.workshare {
+
+ %t1 = "test.test1"() : () -> i32
+ // CHECK: %[[T1:.*]] = "test.test1"
+ // CHECK: fir.store %[[T1]]
+ %t2 = "test.test2"() : () -> i32
+ // CHECK: %[[T2:.*]] = "test.test2"
+ // CHECK: fir.store %[[T2]]
+ %t3 = "test.test3"() : () -> i32
+ // CHECK: %[[T3:.*]] = "test.test3"
+ // CHECK-NOT: fir.store %[[T3]]
+ %t4 = "test.test4"() : () -> i32
+ // CHECK: %[[T4:.*]] = "test.test4"
+ // CHECK: fir.store %[[T4]]
+ %t5 = "test.test5"() : () -> i32
+ // CHECK: %[[T5:.*]] = "test.test5"
+ // CHECK: fir.store %[[T5]]
+ %t6 = "test.test6"() : () -> i32
+ // CHECK: %[[T6:.*]] = "test.test6"
+ // CHECK-NOT: fir.store %[[T6]]
+
+
+ "test.test1"(%t1) : (i32) -> ()
+ "test.test1"(%t2) : (i32) -> ()
+ "test.test1"(%t3) : (i32) -> ()
+
+ %true = arith.constant true
+ fir.if %true {
+ "test.test2"(%t3) : (i32) -> ()
+ }
+
+ %c1_i32 = arith.constant 1 : i32
+
+ %t5_pure_use = arith.addi %t5, %c1_i32 : i32
+
+ %t6_mem_effect_use = "test.test8"(%t6) : (i32) -> i32
+ // CHECK: %[[T6_USE:.*]] = "test.test8"
+ // CHECK: fir.store %[[T6_USE]]
+
+ %c42 = arith.constant 42 : index
+ %c1 = arith.constant 1 : index
+ omp.workshare.loop_wrapper {
+ omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+ "test.test10"(%t1) : (i32) -> ()
+ "test.test10"(%t5_pure_use) : (i32) -> ()
+ "test.test10"(%t6_mem_effect_use) : (i32) -> ()
+ omp.yield
+ }
+ omp.terminator
+ }
+
+ "test.test10"(%t2) : (i32) -> ()
+ fir.if %true {
+ "test.test10"(%t4) : (i32) -> ()
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare4.mlir b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
new file mode 100644
index 00000000000000..02fe90097008db
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare4.mlir
@@ -0,0 +1,59 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// Check that we clean up unused pure operations from the parallel and single
+// regions
+
+// CHECK-LABEL: func.func @wsfunc() {
+// CHECK: %[[VAL_0:.*]] = fir.alloca i32
+// CHECK: omp.parallel {
+// CHECK: omp.single {
+// CHECK: %[[VAL_1:.*]] = "test.test1"() : () -> i32
+// CHECK: %[[VAL_2:.*]] = arith.constant 2 : index
+// CHECK: %[[VAL_3:.*]] = arith.constant 3 : index
+// CHECK: %[[VAL_4:.*]] = arith.addi %[[VAL_2]], %[[VAL_3]] : index
+// CHECK: "test.test3"(%[[VAL_4]]) : (index) -> ()
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_5:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_6:.*]] = arith.constant 42 : index
+// CHECK: omp.wsloop nowait {
+// CHECK: omp.loop_nest (%[[VAL_7:.*]]) : index = (%[[VAL_5]]) to (%[[VAL_6]]) inclusive step (%[[VAL_5]]) {
+// CHECK: "test.test2"() : () -> ()
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.barrier
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @wsfunc() {
+ %a = fir.alloca i32
+ omp.parallel {
+ omp.workshare {
+ %t1 = "test.test1"() : () -> i32
+
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : index
+
+ %c2 = arith.constant 2 : index
+ %c3 = arith.constant 3 : index
+ %add = arith.addi %c2, %c3 : index
+ "test.test3"(%add) : (index) -> ()
+
+ omp.workshare.loop_wrapper {
+ omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+ "test.test2"() : () -> ()
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+
diff --git a/flang/test/Transforms/OpenMP/lower-workshare5.mlir b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
new file mode 100644
index 00000000000000..177f8aa8f86c7c
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare5.mlir
@@ -0,0 +1,42 @@
+// XFAIL: *
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// TODO: We can lower these, but we have no guarantee that the parent of
+// omp.workshare supports multi-block regions, so we fail for now.
+
+func.func @wsfunc() {
+ %a = fir.alloca i32
+ omp.parallel {
+ omp.workshare {
+ ^bb1:
+ %c1 = arith.constant 1 : i32
+ cf.br ^bb3(%c1: i32)
+ ^bb3(%arg1: i32):
+ "test.test2"(%arg1) : (i32) -> ()
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// -----
+
+func.func @wsfunc() {
+ %a = fir.alloca i32
+ omp.parallel {
+ omp.workshare {
+ ^bb1:
+ %c1 = arith.constant 1 : i32
+ cf.br ^bb3(%c1: i32)
+ ^bb2:
+ "test.test2"(%r) : (i32) -> ()
+ omp.terminator
+ ^bb3(%arg1: i32):
+ %r = "test.test2"(%arg1) : (i32) -> i32
+ cf.br ^bb2
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/lower-workshare6.mlir b/flang/test/Transforms/OpenMP/lower-workshare6.mlir
new file mode 100644
index 00000000000000..48379470e92562
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/lower-workshare6.mlir
@@ -0,0 +1,51 @@
+// RUN: fir-opt --split-input-file --lower-workshare --allow-unregistered-dialect %s | FileCheck %s
+
+// Checks that the omp.workshare.loop_wrapper binds to the correct omp.workshare
+
+func.func @wsfunc() {
+ %c1 = arith.constant 1 : index
+ %c42 = arith.constant 42 : index
+ omp.parallel {
+ omp.workshare nowait {
+ omp.parallel {
+ omp.workshare nowait {
+ omp.workshare.loop_wrapper {
+ omp.loop_nest (%arg1) : index = (%c1) to (%c42) inclusive step (%c1) {
+ "test.test2"() : () -> ()
+ omp.yield
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: func.func @wsfunc() {
+// CHECK: %[[VAL_0:.*]] = arith.constant 1 : index
+// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+// CHECK: omp.parallel {
+// CHECK: omp.single nowait {
+// CHECK: omp.parallel {
+// CHECK: omp.wsloop nowait {
+// CHECK: omp.loop_nest (%[[VAL_2:.*]]) : index = (%[[VAL_0]]) to (%[[VAL_1]]) inclusive step (%[[VAL_0]]) {
+// CHECK: "test.test2"() : () -> ()
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+
diff --git a/flang/tools/bbc/bbc.cpp b/flang/tools/bbc/bbc.cpp
index 736d68219581dd..1a7dac1b76bc20 100644
--- a/flang/tools/bbc/bbc.cpp
+++ b/flang/tools/bbc/bbc.cpp
@@ -440,7 +440,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
if (emitFIR && useHLFIR) {
// lower HLFIR to FIR
- fir::createHLFIRToFIRPassPipeline(pm, llvm::OptimizationLevel::O2);
+ fir::createHLFIRToFIRPassPipeline(pm, enableOpenMP,
+ llvm::OptimizationLevel::O2);
if (mlir::failed(pm.run(mlirModule))) {
llvm::errs() << "FATAL: lowering from HLFIR to FIR failed";
return mlir::failure();
@@ -455,6 +456,8 @@ static llvm::LogicalResult convertFortranSourceToMLIR(
// Add O2 optimizer pass pipeline.
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+ if (enableOpenMP)
+ config.EnableOpenMP = true;
config.NSWOnLoopVarInc = setNSW;
fir::registerDefaultInlinerPass(config);
fir::createDefaultFIROptimizerPassPipeline(pm, config);
diff --git a/flang/tools/tco/tco.cpp b/flang/tools/tco/tco.cpp
index a8c64333109aeb..06892cdc3f6a80 100644
--- a/flang/tools/tco/tco.cpp
+++ b/flang/tools/tco/tco.cpp
@@ -138,6 +138,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
return mlir::failure();
} else {
MLIRToLLVMPassPipelineConfig config(llvm::OptimizationLevel::O2);
+ config.EnableOpenMP = true; // assume the input contains OpenMP
config.AliasAnalysis = true; // enabled when optimizing for speed
if (codeGenLLVM) {
// Run only CodeGen passes.
>From a942ae87d9309e2cf1b26e319a1fc5c4eb1f86fc Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Sun, 4 Aug 2024 17:33:52 +0900
Subject: [PATCH 6/6] Add workshare loop wrapper lowerings
Bufferize test
Bufferize test
Bufferize test
Add test for should use workshare lowering
---
.../HLFIR/Transforms/BufferizeHLFIR.cpp | 6 +-
.../Transforms/OptimizedBufferization.cpp | 10 +-
flang/test/HLFIR/bufferize-workshare.fir | 58 ++++++++
.../OpenMP/should-use-workshare-lowering.mlir | 140 ++++++++++++++++++
4 files changed, 209 insertions(+), 5 deletions(-)
create mode 100644 flang/test/HLFIR/bufferize-workshare.fir
create mode 100644 flang/test/Transforms/OpenMP/should-use-workshare-lowering.mlir
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index b608677c526310..1848dbe2c7a2c2 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -26,12 +26,13 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassManager.h"
#include "mlir/Transforms/DialectConversion.h"
-#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
#include "llvm/ADT/TypeSwitch.h"
namespace hlfir {
@@ -792,7 +793,8 @@ struct ElementalOpConversion
// Generate a loop nest looping around the fir.elemental shape and clone
// fir.elemental region inside the inner loop.
hlfir::LoopNest loopNest =
- hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
+ hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
+ flangomp::shouldUseWorkshareLowering(elemental));
auto insPt = builder.saveInsertionPoint();
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index 3a0a98dc594463..f014724861e333 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -20,6 +20,7 @@
#include "flang/Optimizer/HLFIR/HLFIRDialect.h"
#include "flang/Optimizer/HLFIR/HLFIROps.h"
#include "flang/Optimizer/HLFIR/Passes.h"
+#include "flang/Optimizer/OpenMP/Passes.h"
#include "flang/Optimizer/Transforms/Utils.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/Dominance.h"
@@ -482,7 +483,8 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
// Generate a loop nest looping around the hlfir.elemental shape and clone
// hlfir.elemental region inside the inner loop
hlfir::LoopNest loopNest =
- hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
+ hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered(),
+ flangomp::shouldUseWorkshareLowering(elemental));
builder.setInsertionPointToStart(loopNest.body);
auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
loopNest.oneBasedIndices);
@@ -553,7 +555,8 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
llvm::SmallVector<mlir::Value> extents =
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
- hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
+ hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+ flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);
auto arrayElement =
hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
@@ -648,7 +651,8 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
llvm::SmallVector<mlir::Value> extents =
hlfir::getIndexExtents(loc, builder, shape);
hlfir::LoopNest loopNest =
- hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
+ hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true,
+ flangomp::shouldUseWorkshareLowering(assign));
builder.setInsertionPointToStart(loopNest.body);
auto rhsArrayElement =
hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
diff --git a/flang/test/HLFIR/bufferize-workshare.fir b/flang/test/HLFIR/bufferize-workshare.fir
new file mode 100644
index 00000000000000..9b7341ae43398a
--- /dev/null
+++ b/flang/test/HLFIR/bufferize-workshare.fir
@@ -0,0 +1,58 @@
+// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s
+
+// CHECK-LABEL: func.func @simple(
+// CHECK-SAME: %[[VAL_0:.*]]: !fir.ref<!fir.array<42xi32>>) {
+// CHECK: omp.parallel {
+// CHECK: omp.workshare {
+// CHECK: %[[VAL_1:.*]] = arith.constant 42 : index
+// CHECK: %[[VAL_2:.*]] = arith.constant 1 : i32
+// CHECK: %[[VAL_3:.*]] = fir.shape %[[VAL_1]] : (index) -> !fir.shape<1>
+// CHECK: %[[VAL_4:.*]]:2 = hlfir.declare %[[VAL_0]](%[[VAL_3]]) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+// CHECK: %[[VAL_5:.*]] = fir.allocmem !fir.array<42xi32> {bindc_name = ".tmp.array", uniq_name = ""}
+// CHECK: %[[VAL_6:.*]]:2 = hlfir.declare %[[VAL_5]](%[[VAL_3]]) {uniq_name = ".tmp.array"} : (!fir.heap<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.heap<!fir.array<42xi32>>, !fir.heap<!fir.array<42xi32>>)
+// CHECK: %[[VAL_7:.*]] = arith.constant true
+// CHECK: %[[VAL_8:.*]] = arith.constant 1 : index
+// CHECK: omp.workshare.loop_wrapper {
+// CHECK: omp.loop_nest (%[[VAL_9:.*]]) : index = (%[[VAL_8]]) to (%[[VAL_1]]) inclusive step (%[[VAL_8]]) {
+// CHECK: %[[VAL_10:.*]] = hlfir.designate %[[VAL_4]]#0 (%[[VAL_9]]) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: %[[VAL_11:.*]] = fir.load %[[VAL_10]] : !fir.ref<i32>
+// CHECK: %[[VAL_12:.*]] = arith.subi %[[VAL_11]], %[[VAL_2]] : i32
+// CHECK: %[[VAL_13:.*]] = hlfir.designate %[[VAL_6]]#0 (%[[VAL_9]]) : (!fir.heap<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+// CHECK: hlfir.assign %[[VAL_12]] to %[[VAL_13]] temporary_lhs : i32, !fir.ref<i32>
+// CHECK: omp.yield
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: %[[VAL_14:.*]] = fir.undefined tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK: %[[VAL_15:.*]] = fir.insert_value %[[VAL_14]], %[[VAL_7]], [1 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, i1) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK: %[[VAL_16:.*]] = fir.insert_value %[[VAL_15]], %[[VAL_6]]#0, [0 : index] : (tuple<!fir.heap<!fir.array<42xi32>>, i1>, !fir.heap<!fir.array<42xi32>>) -> tuple<!fir.heap<!fir.array<42xi32>>, i1>
+// CHECK: hlfir.assign %[[VAL_6]]#0 to %[[VAL_4]]#0 : !fir.heap<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>
+// CHECK: fir.freemem %[[VAL_6]]#0 : !fir.heap<!fir.array<42xi32>>
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
+// CHECK: return
+// CHECK: }
+func.func @simple(%arg: !fir.ref<!fir.array<42xi32>>) {
+ omp.parallel {
+ omp.workshare {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ %ref = hlfir.designate %array#0 (%i) : (!fir.ref<!fir.array<42xi32>>, index) -> !fir.ref<i32>
+ %val = fir.load %ref : !fir.ref<i32>
+ %sub = arith.subi %val, %c1_i32 : i32
+ hlfir.yield_element %sub : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
diff --git a/flang/test/Transforms/OpenMP/should-use-workshare-lowering.mlir b/flang/test/Transforms/OpenMP/should-use-workshare-lowering.mlir
new file mode 100644
index 00000000000000..229fe592a02b9b
--- /dev/null
+++ b/flang/test/Transforms/OpenMP/should-use-workshare-lowering.mlir
@@ -0,0 +1,140 @@
+// RUN: fir-opt --bufferize-hlfir %s | FileCheck %s
+
+// Checks that we correctly identify when to use the lowering to
+// omp.workshare.loop_wrapper
+
+// CHECK-LABEL: @should_parallelize_0
+// CHECK: omp.workshare.loop_wrapper
+func.func @should_parallelize_0(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
+ omp.workshare {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ hlfir.yield_element %c1_i32 : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @should_parallelize_1
+// CHECK: omp.workshare.loop_wrapper
+func.func @should_parallelize_1(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
+ omp.parallel {
+ omp.workshare {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ hlfir.yield_element %c1_i32 : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+
+// CHECK-LABEL: @should_not_parallelize_0
+// CHECK-NOT: omp.workshare.loop_wrapper
+func.func @should_not_parallelize_0(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
+ omp.workshare {
+ omp.single {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ hlfir.yield_element %c1_i32 : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @should_not_parallelize_1
+// CHECK-NOT: omp.workshare.loop_wrapper
+func.func @should_not_parallelize_1(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
+ omp.workshare {
+ omp.critical {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ hlfir.yield_element %c1_i32 : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @should_not_parallelize_2
+// CHECK-NOT: omp.workshare.loop_wrapper
+func.func @should_not_parallelize_2(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
+ omp.workshare {
+ omp.parallel {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ hlfir.yield_element %c1_i32 : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
+
+// CHECK-LABEL: @should_not_parallelize_3
+// CHECK-NOT: omp.workshare.loop_wrapper
+func.func @should_not_parallelize_3(%arg: !fir.ref<!fir.array<42xi32>>, %idx : index) {
+ omp.workshare {
+ omp.parallel {
+ omp.workshare {
+ omp.parallel {
+ %c42 = arith.constant 42 : index
+ %c1_i32 = arith.constant 1 : i32
+ %shape = fir.shape %c42 : (index) -> !fir.shape<1>
+ %array:2 = hlfir.declare %arg(%shape) {uniq_name = "array"} : (!fir.ref<!fir.array<42xi32>>, !fir.shape<1>) -> (!fir.ref<!fir.array<42xi32>>, !fir.ref<!fir.array<42xi32>>)
+ %elemental = hlfir.elemental %shape unordered : (!fir.shape<1>) -> !hlfir.expr<42xi32> {
+ ^bb0(%i: index):
+ hlfir.yield_element %c1_i32 : i32
+ }
+ hlfir.assign %elemental to %array#0 : !hlfir.expr<42xi32>, !fir.ref<!fir.array<42xi32>>
+ hlfir.destroy %elemental : !hlfir.expr<42xi32>
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ omp.terminator
+ }
+ return
+}
More information about the llvm-branch-commits
mailing list