[llvm-branch-commits] [flang] [mlir] [MLIR][Flang][OpenMP] Make omp.wsloop into a loop wrapper (PR #88403)
via llvm-branch-commits
llvm-branch-commits at lists.llvm.org
Thu Apr 11 08:56:27 PDT 2024
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
@llvm/pr-subscribers-mlir-openmp
Author: Sergio Afonso (skatrak)
<details>
<summary>Changes</summary>
This patch updates the definition of `omp.wsloop` to enforce the restrictions of a wrapper operation. Because this operation is widely used, the patch makes several changes:
- Update the MLIR definition of `omp.wsloop`, as well as its parser/printer, builder and verifier.
- Update verifiers for `omp.ordered.region`, `omp.cancel` and `omp.cancellation_point` to correctly check for a parent `omp.wsloop`.
- Update MLIR to LLVM IR translation of `omp.wsloop` to keep working after the change in representation. Another patch should be created to reduce the current code duplication between `omp.wsloop` and `omp.simd` after introducing a common `omp.loop_nest` operation.
- Update the `scf.parallel` lowering pass to OpenMP to produce the new expected representation.
- Update flang lowering to implement the `omp.wsloop` representation changes, including changes to `lastprivate` and `reduction` handling that avoid adding operations into a wrapper and attach entry block arguments to the right operation.
- Fix unit tests broken due to the representation change.
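
The shape that the updated tests expect (an `omp.wsloop` whose region holds a single `omp.loop_nest` followed by `omp.terminator`) can be sketched with a toy model. This is only an illustration: `ToyOp` and `isValidWrapper` are hypothetical names, they are not part of the MLIR API, and the actual verifier rules are not visible in the truncated diff. With this model, a wrapper holding just `omp.loop_nest` and `omp.terminator` passes, while one that also contains, say, a `fir.store` before the loop nest is rejected.

```cpp
#include <string>
#include <vector>

// Toy stand-in for an operation; NOT the MLIR API. Each op has a name and the
// ops nested in its single region, terminator included.
struct ToyOp {
  std::string name;
  std::vector<ToyOp> body;
};

// Assumed rule for this sketch: an omp.wsloop wrapper is well-formed when its
// region contains exactly one omp.loop_nest followed by omp.terminator, so no
// other operations may be placed directly inside the wrapper.
bool isValidWrapper(const ToyOp &op) {
  if (op.name != "omp.wsloop" || op.body.size() != 2)
    return false;
  return op.body[0].name == "omp.loop_nest" &&
         op.body[1].name == "omp.terminator";
}
```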
---
Patch is 758.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88403.diff
110 Files Affected:
- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+27-21)
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+43-74)
- (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+61-47)
- (modified) flang/test/Lower/OpenMP/FIR/copyin.f90 (+11-5)
- (modified) flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90 (+4-1)
- (modified) flang/test/Lower/OpenMP/FIR/location.f90 (+10-7)
- (modified) flang/test/Lower/OpenMP/FIR/parallel-lastprivate-clause-scalar.f90 (+36-12)
- (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90 (+26-23)
- (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause.f90 (+60-54)
- (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop-firstpriv.f90 (+10-2)
- (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop.f90 (+74-54)
- (modified) flang/test/Lower/OpenMP/FIR/stop-stmt-in-region.f90 (+21-18)
- (modified) flang/test/Lower/OpenMP/FIR/target.f90 (+4-1)
- (modified) flang/test/Lower/OpenMP/FIR/unstructured.f90 (+110-89)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-chunks.f90 (+28-19)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-collapse.f90 (+16-13)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-monotonic.f90 (+17-13)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-nonmonotonic.f90 (+17-14)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-ordered.f90 (+12-6)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+106-85)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add.f90 (+106-85)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+3-1)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand.f90 (+3-1)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+3-1)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor.f90 (+3-1)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+3-1)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior.f90 (+3-1)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+75-69)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv.f90 (+75-69)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+75-69)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv.f90 (+75-69)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+18-13)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max.f90 (+18-13)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+18-14)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min.f90 (+18-14)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-simd.f90 (+16-13)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-variable.f90 (+93-79)
- (modified) flang/test/Lower/OpenMP/FIR/wsloop.f90 (+36-30)
- (modified) flang/test/Lower/OpenMP/Todo/omp-default-clause-inner-loop.f90 (+4-1)
- (modified) flang/test/Lower/OpenMP/copyin.f90 (+17-11)
- (modified) flang/test/Lower/OpenMP/default-clause-byref.f90 (+4-1)
- (modified) flang/test/Lower/OpenMP/default-clause.f90 (+4-1)
- (modified) flang/test/Lower/OpenMP/hlfir-wsloop.f90 (+7-5)
- (modified) flang/test/Lower/OpenMP/lastprivate-commonblock.f90 (+31-28)
- (modified) flang/test/Lower/OpenMP/lastprivate-iv.f90 (+48-42)
- (modified) flang/test/Lower/OpenMP/location.f90 (+10-7)
- (modified) flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90 (+36-12)
- (modified) flang/test/Lower/OpenMP/parallel-private-clause-fixes.f90 (+26-23)
- (modified) flang/test/Lower/OpenMP/parallel-private-clause.f90 (+56-50)
- (modified) flang/test/Lower/OpenMP/parallel-wsloop-firstpriv.f90 (+8-2)
- (modified) flang/test/Lower/OpenMP/parallel-wsloop.f90 (+79-59)
- (modified) flang/test/Lower/OpenMP/stop-stmt-in-region.f90 (+21-18)
- (modified) flang/test/Lower/OpenMP/target.f90 (+4-1)
- (modified) flang/test/Lower/OpenMP/unstructured.f90 (+110-89)
- (modified) flang/test/Lower/OpenMP/wsloop-chunks.f90 (+28-19)
- (modified) flang/test/Lower/OpenMP/wsloop-collapse.f90 (+16-13)
- (modified) flang/test/Lower/OpenMP/wsloop-monotonic.f90 (+10-8)
- (modified) flang/test/Lower/OpenMP/wsloop-nonmonotonic.f90 (+11-8)
- (modified) flang/test/Lower/OpenMP/wsloop-ordered.f90 (+12-6)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+120-99)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+10-8)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 (+10-8)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add.f90 (+120-99)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+19-16)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+27-24)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+13-11)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand.f90 (+13-11)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+4-2)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor.f90 (+4-2)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+13-11)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior.f90 (+13-11)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or.f90 (+70-64)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+48-42)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+14-12)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 (+14-12)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max.f90 (+48-42)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+49-43)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-min.f90 (+49-43)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-min2.f90 (+9-7)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+113-99)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul.f90 (+113-99)
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-multi.f90 (+26-23)
- (modified) flang/test/Lower/OpenMP/wsloop-simd.f90 (+16-13)
- (modified) flang/test/Lower/OpenMP/wsloop-unstructured.f90 (+21-18)
- (modified) flang/test/Lower/OpenMP/wsloop-variable.f90 (+91-76)
- (modified) flang/test/Lower/OpenMP/wsloop.f90 (+39-33)
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+21-35)
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+40-12)
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+47-85)
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+32-34)
- (modified) mlir/test/CAPI/execution_engine.c (+5-2)
- (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+36-28)
- (modified) mlir/test/Conversion/SCFToOpenMP/reductions.mlir (+4)
- (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+23-8)
- (modified) mlir/test/Dialect/LLVMIR/legalize-for-export.mlir (+11-8)
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+167-90)
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+326-236)
- (modified) mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir (+7-4)
- (modified) mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir (+10-7)
- (modified) mlir/test/Target/LLVMIR/omptarget-wsloop.mlir (+12-6)
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+404-337)
- (modified) mlir/test/Target/LLVMIR/openmp-nested.mlir (+18-12)
- (modified) mlir/test/Target/LLVMIR/openmp-reduction.mlir (+63-50)
``````````diff
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e114ab9f4548ab..645c351ac6c08c 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -133,8 +133,14 @@ void DataSharingProcessor::insertBarrier() {
}
void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
+ mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+ mlir::omp::LoopNestOp loopOp;
+ if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
+ loopOp = wrapper.isWrapper()
+ ? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
+ : nullptr;
+
bool cmpCreated = false;
- mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
for (const omp::Clause &clause : clauses) {
if (clause.id != llvm::omp::OMPC_lastprivate)
continue;
@@ -213,18 +219,20 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
// Update the original variable just before exiting the worksharing
// loop. Conversion as follows:
//
- // omp.wsloop {
- // omp.wsloop { ...
- // ... store
- // store ===> %v = arith.addi %iv, %step
- // omp.yield %cmp = %step < 0 ? %v < %ub : %v > %ub
- // } fir.if %cmp {
- // fir.store %v to %loopIV
- // ^%lpv_update_blk:
- // }
- // omp.yield
- // }
- //
+ // omp.wsloop { omp.wsloop {
+ // omp.loop_nest { omp.loop_nest {
+ // ... ...
+ // store ===> store
+ // omp.yield %v = arith.addi %iv, %step
+ // } %cmp = %step < 0 ? %v < %ub : %v > %ub
+ // omp.terminator fir.if %cmp {
+ // } fir.store %v to %loopIV
+ // ^%lpv_update_blk:
+ // }
+ // omp.yield
+ // }
+ // omp.terminator
+ // }
// Only generate the compare once in presence of multiple LastPrivate
// clauses.
@@ -232,14 +240,13 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
continue;
cmpCreated = true;
- mlir::Location loc = op->getLoc();
- mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+ mlir::Location loc = loopOp.getLoc();
+ mlir::Operation *lastOper = loopOp.getRegion().back().getTerminator();
firOpBuilder.setInsertionPoint(lastOper);
- mlir::Value iv = op->getRegion(0).front().getArguments()[0];
- mlir::Value ub =
- mlir::dyn_cast<mlir::omp::WsloopOp>(op).getUpperBound()[0];
- mlir::Value step = mlir::dyn_cast<mlir::omp::WsloopOp>(op).getStep()[0];
+ mlir::Value iv = loopOp.getIVs()[0];
+ mlir::Value ub = loopOp.getUpperBound()[0];
+ mlir::Value step = loopOp.getStep()[0];
// v = iv + step
// cmp = step < 0 ? v < ub : v > ub
@@ -258,7 +265,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
assert(loopIV && "loopIV was not set");
- firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+ firOpBuilder.create<fir::StoreOp>(loopOp.getLoc(), v, loopIV);
lastPrivIP = firOpBuilder.saveInsertionPoint();
} else {
TODO(converter.getCurrentLocation(),
@@ -266,7 +273,6 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
"simd/worksharing-loop");
}
}
- firOpBuilder.restoreInsertionPoint(localInsPt);
}
void DataSharingProcessor::collectSymbols(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1800fcb19dcd2e..b21351382b6bdf 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1626,7 +1626,9 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
static llvm::SmallVector<const Fortran::semantics::Symbol *>
genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
mlir::Location &loc,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
+ llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
+ llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
auto &region = op->getRegion(0);
@@ -1637,6 +1639,14 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
llvm::SmallVector<mlir::Location> locs(args.size(), loc);
firOpBuilder.createBlock(&region, {}, tiv, locs);
+
+ // Bind the entry block arguments of parent wrappers to the corresponding
+ // symbols. Do it here so that any hlfir.declare operations created as a
+ // result are inserted inside of the omp.loop_nest rather than the wrapper
+ // operations.
+ for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs))
+ converter.bindSymbol(*arg, prv);
+
// The argument is not currently in memory, so make a temporary for the
// argument, and store it there, then bind that location to the argument.
mlir::Operation *storeOp = nullptr;
@@ -1650,58 +1660,6 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
}
-static llvm::SmallVector<const Fortran::semantics::Symbol *>
-genLoopAndReductionVars(
- mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
- mlir::Location &loc,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
- llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
- llvm::ArrayRef<mlir::Type> reductionTypes) {
- fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
- llvm::SmallVector<mlir::Type> blockArgTypes;
- llvm::SmallVector<mlir::Location> blockArgLocs;
- blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
- blockArgLocs.reserve(blockArgTypes.size());
- mlir::Block *entryBlock;
-
- if (loopArgs.size()) {
- std::size_t loopVarTypeSize = 0;
- for (const Fortran::semantics::Symbol *arg : loopArgs)
- loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
- mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
- std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
- loopVarType);
- std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
- }
- if (reductionArgs.size()) {
- llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
- std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
- }
- entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
- blockArgLocs);
- // The argument is not currently in memory, so make a temporary for the
- // argument, and store it there, then bind that location to the argument.
- if (loopArgs.size()) {
- mlir::Operation *storeOp = nullptr;
- for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
- mlir::Value indexVal =
- fir::getBase(op->getRegion(0).front().getArgument(argIndex));
- storeOp =
- createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
- }
- firOpBuilder.setInsertionPointAfter(storeOp);
- }
- // Bind the reduction arguments to their block arguments
- for (auto [arg, prv] : llvm::zip_equal(
- reductionArgs,
- llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
- converter.bindSymbol(*arg, prv);
- }
-
- return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
-}
-
static void createSimd(Fortran::lower::AbstractConverter &converter,
Fortran::semantics::SemanticsContext &semaCtx,
Fortran::lower::pft::Evaluation &eval,
@@ -1797,28 +1755,26 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
if (ReductionProcessor::doReductionByRef(reductionVars))
byrefOperand = firOpBuilder.getUnitAttr();
- auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
- loc, lowerBound, upperBound, step, linearVars, linearStepVars,
- reductionVars,
+ auto wsloopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
+ loc, linearVars, linearStepVars, reductionVars,
reductionDeclSymbols.empty()
? nullptr
: mlir::ArrayAttr::get(firOpBuilder.getContext(),
reductionDeclSymbols),
scheduleValClauseOperand, scheduleChunkClauseOperand,
- /*schedule_modifiers=*/nullptr,
- /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
- orderedClauseOperand, orderClauseOperand,
- /*inclusive=*/firOpBuilder.getUnitAttr());
+ /*schedule_modifiers=*/nullptr, /*simd_modifier=*/nullptr,
+ nowaitClauseOperand, byrefOperand, orderedClauseOperand,
+ orderClauseOperand);
// Handle attribute based clauses.
if (cp.processOrdered(orderedClauseOperand))
- wsLoopOp.setOrderedValAttr(orderedClauseOperand);
+ wsloopOp.setOrderedValAttr(orderedClauseOperand);
if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
scheduleSimdClauseOperand)) {
- wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
- wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
- wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
+ wsloopOp.setScheduleValAttr(scheduleValClauseOperand);
+ wsloopOp.setScheduleModifierAttr(scheduleModClauseOperand);
+ wsloopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
}
// In FORTRAN `nowait` clause occur at the end of `omp do` directive.
// i.e
@@ -1828,23 +1784,36 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
if (endClauseList) {
if (ClauseProcessor(converter, semaCtx, *endClauseList)
.processNowait(nowaitClauseOperand))
- wsLoopOp.setNowaitAttr(nowaitClauseOperand);
+ wsloopOp.setNowaitAttr(nowaitClauseOperand);
}
+ // Create omp.wsloop wrapper and populate entry block arguments with reduction
+ // variables.
+ llvm::SmallVector<mlir::Location> reductionLocs(reductionSymbols.size(), loc);
+ mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock(
+ &wsloopOp.getRegion(), {}, reductionTypes, reductionLocs);
+ firOpBuilder.setInsertionPoint(
+ Fortran::lower::genOpenMPTerminator(firOpBuilder, wsloopOp, loc));
+
+ // Create nested omp.loop_nest and fill body with loop contents.
+ auto loopOp = firOpBuilder.create<mlir::omp::LoopNestOp>(
+ loc, lowerBound, upperBound, step,
+ /*inclusive=*/firOpBuilder.getUnitAttr());
+
auto *nestedEval = getCollapsedLoopEval(
eval, Fortran::lower::getCollapseValue(beginClauseList));
auto ivCallback = [&](mlir::Operation *op) {
- return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
- reductionTypes);
+ return genLoopVars(op, converter, loc, iv, reductionSymbols,
+ wsloopEntryBlock->getArguments());
};
createBodyOfOp<mlir::omp::WsloopOp>(
- *wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
- .setClauses(&beginClauseList)
- .setDataSharingProcessor(&dsp)
- .setReductions(&reductionSymbols, &reductionTypes)
- .setGenRegionEntryCb(ivCallback));
+ *loopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+ .setClauses(&beginClauseList)
+ .setDataSharingProcessor(&dsp)
+ .setReductions(&reductionSymbols, &reductionTypes)
+ .setGenRegionEntryCb(ivCallback));
}
static void createSimdWsloop(
@@ -2430,8 +2399,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder,
mlir::Operation *op,
mlir::Location loc) {
- if (mlir::isa<mlir::omp::WsloopOp, mlir::omp::DeclareReductionOp,
- mlir::omp::AtomicUpdateOp, mlir::omp::LoopNestOp>(op))
+ if (mlir::isa<mlir::omp::AtomicUpdateOp, mlir::omp::DeclareReductionOp,
+ mlir::omp::LoopNestOp>(op))
return builder.create<mlir::omp::YieldOp>(loc);
return builder.create<mlir::omp::TerminatorOp>(loc);
}
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index fa7979e8875afc..c7c609bbb35623 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -7,15 +7,17 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
omp.parallel {
%1 = fir.alloca i32 {adapt.valuebyref, pinned}
%2 = fir.load %arg0 : !fir.ref<i32>
- omp.wsloop nowait
- for (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32) {
- fir.store %arg2 to %1 : !fir.ref<i32>
- %3 = fir.load %1 : !fir.ref<i32>
- %4 = fir.convert %3 : (i32) -> i64
- %5 = arith.subi %4, %c1_i64 : i64
- %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
- fir.store %3 to %6 : !fir.ref<i32>
- omp.yield
+ omp.wsloop nowait {
+ omp.loop_nest (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32) {
+ fir.store %arg2 to %1 : !fir.ref<i32>
+ %3 = fir.load %1 : !fir.ref<i32>
+ %4 = fir.convert %3 : (i32) -> i64
+ %5 = arith.subi %4, %c1_i64 : i64
+ %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+ fir.store %3 to %6 : !fir.ref<i32>
+ omp.yield
+ }
+ omp.terminator
}
omp.terminator
}
@@ -31,7 +33,7 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
// CHECK: %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {pinned} : (i64) -> !llvm.ptr
// CHECK: %[[N:.*]] = llvm.load %[[N_REF]] : !llvm.ptr -> i32
// CHECK: omp.wsloop nowait
-// CHECK-SAME: for (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
+// CHECK-NEXT: omp.loop_nest (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
// CHECK: llvm.store %[[I]], %[[I_VAR]] : i32, !llvm.ptr
// CHECK: %[[I1:.*]] = llvm.load %[[I_VAR]] : !llvm.ptr -> i32
// CHECK: %[[I1_EXT:.*]] = llvm.sext %[[I1]] : i32 to i64
@@ -42,6 +44,8 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
// CHECK: }
// CHECK: omp.terminator
// CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
// CHECK: llvm.return
// CHECK: }
@@ -79,13 +83,16 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
omp.parallel {
%c1 = arith.constant 1 : i32
%c50 = arith.constant 50 : i32
- omp.wsloop for (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
- %1 = fir.convert %indx : (i32) -> i64
- %c1_i64 = arith.constant 1 : i64
- %2 = arith.subi %1, %c1_i64 : i64
- %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
- fir.store %indx to %3 : !fir.ref<i32>
- omp.yield
+ omp.wsloop {
+ omp.loop_nest (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
+ %1 = fir.convert %indx : (i32) -> i64
+ %c1_i64 = arith.constant 1 : i64
+ %2 = arith.subi %1, %c1_i64 : i64
+ %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+ fir.store %indx to %3 : !fir.ref<i32>
+ omp.yield
+ }
+ omp.terminator
}
omp.terminator
}
@@ -98,9 +105,11 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
// CHECK: omp.parallel {
// CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
// CHECK: %[[C50:.*]] = llvm.mlir.constant(50 : i32) : i32
-// CHECK: omp.wsloop for (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
-// CHECK: llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
-// CHECK: omp.yield
+// CHECK: omp.wsloop {
+// CHECK-NEXT: omp.loop_nest (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
+// CHECK: llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
+// CHECK: omp.yield
+// CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: llvm.return
@@ -708,18 +717,20 @@ func.func @_QPsb() {
// CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
// CHECK: %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
// CHECK: omp.parallel {
-// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
-// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
-// CHECK: %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
-// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
-// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK: %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
-// CHECK: %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK: %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
-// CHECK: %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
-// CHECK: %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
-// CHECK: llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
-// CHECK: omp.yield
+// CHECK: omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) {
+// CHECK-NEXT: omp.loop_nest
+// CHECK: %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK: %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
+// CHECK: %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
+// CHECK: %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK: %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
+// CHECK: %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK: %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
+// CHECK: %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
+// CHECK: %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
+// CHECK: llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
+// CHECK: omp.yield
+// CHECK: omp.terminator
// CHECK: omp.terminator
// CHECK: llvm.return
@@ -747,21 +758,24 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
%c1_i32 = arith.constant 1 : i32
%c100_i32 = arith.constant 100 : i32
%c1_i32_0 = arith.constant 1 : i32
- omp.wsloop reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
- fir.store %arg1 to %3 : !fir.ref<i32>
- %4 = fir.load %3 : !fir.ref<i32>
- %5 = fir.convert %4 : (i32) -> i64
- %c1_i64 = arith.constant 1 : i64
- %6 = arith.subi %5, %c1_i64 : i64
- %7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
- %8 = fir.load %7 : !fir.ref<!fir.logical<4>>
- %lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
- %lprv1 = fir.convert %lprv : (!fir.logical<...
[truncated]
``````````
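
The last-iteration check described in the `insertLastPrivateCompare` comments (`v = iv + step`, then `cmp = step < 0 ? v < ub : v > ub`) can be tried out in plain C++. The following is a minimal sketch, assuming an inclusive upper bound as in the updated tests; `isLastIteration` is a hypothetical helper name, not a function from flang.

```cpp
#include <cstdint>

// Mirrors the comment in insertLastPrivateCompare:
//   v   = iv + step
//   cmp = step < 0 ? v < ub : v > ub
// The result is true only when the next induction value would fall outside
// the (inclusive) upper bound, i.e. when iv belongs to the last iteration.
bool isLastIteration(std::int64_t iv, std::int64_t ub, std::int64_t step) {
  const std::int64_t v = iv + step;
  return step < 0 ? v < ub : v > ub;
}
```

For a loop from 1 to 10 with step 1, only `iv == 10` satisfies the check; for a loop counting down to 1 with step -1, only `iv == 1` does. In the patch this comparison is emitted just before the terminator of the `omp.loop_nest` body rather than inside the `omp.wsloop` wrapper.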
</details>
https://github.com/llvm/llvm-project/pull/88403