[llvm-branch-commits] [flang] [mlir] [MLIR][Flang][OpenMP] Make omp.wsloop into a loop wrapper (PR #88403)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Apr 11 08:56:27 PDT 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-openmp

Author: Sergio Afonso (skatrak)

<details>
<summary>Changes</summary>

This patch updates the definition of `omp.wsloop` to enforce the restrictions of a wrapper operation. Because this operation is widely used, the patch makes several changes:

- Update the MLIR definition of `omp.wsloop`, as well as its parser/printer, builder and verifier.
- Update verifiers for `omp.ordered.region`, `omp.cancel` and `omp.cancellation_point` to correctly check for a parent `omp.wsloop`.
- Update MLIR to LLVM IR translation of `omp.wsloop` to keep working after the change in representation. Another patch should be created to reduce the current code duplication between `omp.wsloop` and `omp.simd` after introducing a common `omp.loop_nest` operation.
- Update the `scf.parallel` lowering pass to OpenMP to produce the new expected representation.
- Update Flang lowering to implement the `omp.wsloop` representation changes, including changes to `lastprivate` and `reduction` handling that avoid adding operations to a wrapper and attach entry block arguments to the right operation.
- Fix unit tests broken due to the representation change.
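
As a rough illustration of the representation change described above, the following sketch contrasts the old and new textual forms of `omp.wsloop`. It is adapted from the test updates in the diff; the value names `%i`, `%lb`, `%ub` and `%step` are placeholders rather than names taken from any particular test, and the loop body is elided.

```mlir
// Before: omp.wsloop carries the loop bounds, step and body itself.
omp.wsloop nowait
for (%i) : i32 = (%lb) to (%ub) inclusive step (%step) {
  // ... loop body ...
  omp.yield
}

// After: omp.wsloop is a wrapper around a nested omp.loop_nest, which
// now owns the bounds, step and body. The wrapper region ends with
// omp.terminator, while the loop nest still ends with omp.yield.
omp.wsloop nowait {
  omp.loop_nest (%i) : i32 = (%lb) to (%ub) inclusive step (%step) {
    // ... loop body ...
    omp.yield
  }
  omp.terminator
}
```

Clauses such as `nowait` and `reduction(...)` stay on the `omp.wsloop` wrapper, which is why the Flang lowering changes bind the wrapper's entry block arguments from inside the `omp.loop_nest` body.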

---

Patch is 758.19 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/88403.diff


110 Files Affected:

- (modified) flang/lib/Lower/OpenMP/DataSharingProcessor.cpp (+27-21) 
- (modified) flang/lib/Lower/OpenMP/OpenMP.cpp (+43-74) 
- (modified) flang/test/Fir/convert-to-llvm-openmp-and-fir.fir (+61-47) 
- (modified) flang/test/Lower/OpenMP/FIR/copyin.f90 (+11-5) 
- (modified) flang/test/Lower/OpenMP/FIR/lastprivate-commonblock.f90 (+4-1) 
- (modified) flang/test/Lower/OpenMP/FIR/location.f90 (+10-7) 
- (modified) flang/test/Lower/OpenMP/FIR/parallel-lastprivate-clause-scalar.f90 (+36-12) 
- (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause-fixes.f90 (+26-23) 
- (modified) flang/test/Lower/OpenMP/FIR/parallel-private-clause.f90 (+60-54) 
- (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop-firstpriv.f90 (+10-2) 
- (modified) flang/test/Lower/OpenMP/FIR/parallel-wsloop.f90 (+74-54) 
- (modified) flang/test/Lower/OpenMP/FIR/stop-stmt-in-region.f90 (+21-18) 
- (modified) flang/test/Lower/OpenMP/FIR/target.f90 (+4-1) 
- (modified) flang/test/Lower/OpenMP/FIR/unstructured.f90 (+110-89) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-chunks.f90 (+28-19) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-collapse.f90 (+16-13) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-monotonic.f90 (+17-13) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-nonmonotonic.f90 (+17-14) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-ordered.f90 (+12-6) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add-byref.f90 (+106-85) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-add.f90 (+106-85) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand-byref.f90 (+3-1) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-iand.f90 (+3-1) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor-byref.f90 (+3-1) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ieor.f90 (+3-1) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior-byref.f90 (+3-1) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-ior.f90 (+3-1) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv-byref.f90 (+75-69) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-eqv.f90 (+75-69) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv-byref.f90 (+75-69) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-logical-neqv.f90 (+75-69) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max-byref.f90 (+18-13) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-max.f90 (+18-13) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min-byref.f90 (+18-14) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-reduction-min.f90 (+18-14) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-simd.f90 (+16-13) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop-variable.f90 (+93-79) 
- (modified) flang/test/Lower/OpenMP/FIR/wsloop.f90 (+36-30) 
- (modified) flang/test/Lower/OpenMP/Todo/omp-default-clause-inner-loop.f90 (+4-1) 
- (modified) flang/test/Lower/OpenMP/copyin.f90 (+17-11) 
- (modified) flang/test/Lower/OpenMP/default-clause-byref.f90 (+4-1) 
- (modified) flang/test/Lower/OpenMP/default-clause.f90 (+4-1) 
- (modified) flang/test/Lower/OpenMP/hlfir-wsloop.f90 (+7-5) 
- (modified) flang/test/Lower/OpenMP/lastprivate-commonblock.f90 (+31-28) 
- (modified) flang/test/Lower/OpenMP/lastprivate-iv.f90 (+48-42) 
- (modified) flang/test/Lower/OpenMP/location.f90 (+10-7) 
- (modified) flang/test/Lower/OpenMP/parallel-lastprivate-clause-scalar.f90 (+36-12) 
- (modified) flang/test/Lower/OpenMP/parallel-private-clause-fixes.f90 (+26-23) 
- (modified) flang/test/Lower/OpenMP/parallel-private-clause.f90 (+56-50) 
- (modified) flang/test/Lower/OpenMP/parallel-wsloop-firstpriv.f90 (+8-2) 
- (modified) flang/test/Lower/OpenMP/parallel-wsloop.f90 (+79-59) 
- (modified) flang/test/Lower/OpenMP/stop-stmt-in-region.f90 (+21-18) 
- (modified) flang/test/Lower/OpenMP/target.f90 (+4-1) 
- (modified) flang/test/Lower/OpenMP/unstructured.f90 (+110-89) 
- (modified) flang/test/Lower/OpenMP/wsloop-chunks.f90 (+28-19) 
- (modified) flang/test/Lower/OpenMP/wsloop-collapse.f90 (+16-13) 
- (modified) flang/test/Lower/OpenMP/wsloop-monotonic.f90 (+10-8) 
- (modified) flang/test/Lower/OpenMP/wsloop-nonmonotonic.f90 (+11-8) 
- (modified) flang/test/Lower/OpenMP/wsloop-ordered.f90 (+12-6) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-byref.f90 (+120-99) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir-byref.f90 (+10-8) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add-hlfir.f90 (+10-8) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-add.f90 (+120-99) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array.f90 (+19-16) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-array2.f90 (+27-24) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand-byref.f90 (+13-11) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-iand.f90 (+13-11) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor-byref.f90 (+4-2) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ieor.f90 (+4-2) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior-byref.f90 (+13-11) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-ior.f90 (+13-11) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and-byref.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-and.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv-byref.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-eqv.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv-byref.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-neqv.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or-byref.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-logical-or.f90 (+70-64) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-byref.f90 (+48-42) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir-byref.f90 (+14-12) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max-hlfir.f90 (+14-12) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-max.f90 (+48-42) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-min-byref.f90 (+49-43) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-min.f90 (+49-43) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-min2.f90 (+9-7) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul-byref.f90 (+113-99) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-mul.f90 (+113-99) 
- (modified) flang/test/Lower/OpenMP/wsloop-reduction-multi.f90 (+26-23) 
- (modified) flang/test/Lower/OpenMP/wsloop-simd.f90 (+16-13) 
- (modified) flang/test/Lower/OpenMP/wsloop-unstructured.f90 (+21-18) 
- (modified) flang/test/Lower/OpenMP/wsloop-variable.f90 (+91-76) 
- (modified) flang/test/Lower/OpenMP/wsloop.f90 (+39-33) 
- (modified) mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td (+21-35) 
- (modified) mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp (+40-12) 
- (modified) mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp (+47-85) 
- (modified) mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp (+32-34) 
- (modified) mlir/test/CAPI/execution_engine.c (+5-2) 
- (modified) mlir/test/Conversion/OpenMPToLLVM/convert-to-llvmir.mlir (+36-28) 
- (modified) mlir/test/Conversion/SCFToOpenMP/reductions.mlir (+4) 
- (modified) mlir/test/Conversion/SCFToOpenMP/scf-to-openmp.mlir (+23-8) 
- (modified) mlir/test/Dialect/LLVMIR/legalize-for-export.mlir (+11-8) 
- (modified) mlir/test/Dialect/OpenMP/invalid.mlir (+167-90) 
- (modified) mlir/test/Dialect/OpenMP/ops.mlir (+326-236) 
- (modified) mlir/test/Target/LLVMIR/omptarget-parallel-wsloop.mlir (+7-4) 
- (modified) mlir/test/Target/LLVMIR/omptarget-wsloop-collapsed.mlir (+10-7) 
- (modified) mlir/test/Target/LLVMIR/omptarget-wsloop.mlir (+12-6) 
- (modified) mlir/test/Target/LLVMIR/openmp-llvm.mlir (+404-337) 
- (modified) mlir/test/Target/LLVMIR/openmp-nested.mlir (+18-12) 
- (modified) mlir/test/Target/LLVMIR/openmp-reduction.mlir (+63-50) 


``````````diff
diff --git a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
index e114ab9f4548ab..645c351ac6c08c 100644
--- a/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/DataSharingProcessor.cpp
@@ -133,8 +133,14 @@ void DataSharingProcessor::insertBarrier() {
 }
 
 void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
+  mlir::OpBuilder::InsertionGuard guard(firOpBuilder);
+  mlir::omp::LoopNestOp loopOp;
+  if (auto wrapper = mlir::dyn_cast<mlir::omp::LoopWrapperInterface>(op))
+    loopOp = wrapper.isWrapper()
+                 ? mlir::cast<mlir::omp::LoopNestOp>(wrapper.getWrappedLoop())
+                 : nullptr;
+
   bool cmpCreated = false;
-  mlir::OpBuilder::InsertPoint localInsPt = firOpBuilder.saveInsertionPoint();
   for (const omp::Clause &clause : clauses) {
     if (clause.id != llvm::omp::OMPC_lastprivate)
       continue;
@@ -213,18 +219,20 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       // Update the original variable just before exiting the worksharing
       // loop. Conversion as follows:
       //
-      //                       omp.wsloop {
-      // omp.wsloop {            ...
-      //    ...                  store
-      //    store       ===>     %v = arith.addi %iv, %step
-      //    omp.yield            %cmp = %step < 0 ? %v < %ub : %v > %ub
-      // }                       fir.if %cmp {
-      //                           fir.store %v to %loopIV
-      //                           ^%lpv_update_blk:
-      //                         }
-      //                         omp.yield
-      //                       }
-      //
+      // omp.wsloop {             omp.wsloop {
+      //   omp.loop_nest {          omp.loop_nest {
+      //     ...                      ...
+      //     store          ===>      store
+      //     omp.yield                %v = arith.addi %iv, %step
+      //   }                          %cmp = %step < 0 ? %v < %ub : %v > %ub
+      //   omp.terminator             fir.if %cmp {
+      // }                              fir.store %v to %loopIV
+      //                                ^%lpv_update_blk:
+      //                              }
+      //                              omp.yield
+      //                            }
+      //                            omp.terminator
+      //                          }
 
       // Only generate the compare once in presence of multiple LastPrivate
       // clauses.
@@ -232,14 +240,13 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
         continue;
       cmpCreated = true;
 
-      mlir::Location loc = op->getLoc();
-      mlir::Operation *lastOper = op->getRegion(0).back().getTerminator();
+      mlir::Location loc = loopOp.getLoc();
+      mlir::Operation *lastOper = loopOp.getRegion().back().getTerminator();
       firOpBuilder.setInsertionPoint(lastOper);
 
-      mlir::Value iv = op->getRegion(0).front().getArguments()[0];
-      mlir::Value ub =
-          mlir::dyn_cast<mlir::omp::WsloopOp>(op).getUpperBound()[0];
-      mlir::Value step = mlir::dyn_cast<mlir::omp::WsloopOp>(op).getStep()[0];
+      mlir::Value iv = loopOp.getIVs()[0];
+      mlir::Value ub = loopOp.getUpperBound()[0];
+      mlir::Value step = loopOp.getStep()[0];
 
       // v = iv + step
       // cmp = step < 0 ? v < ub : v > ub
@@ -258,7 +265,7 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
       auto ifOp = firOpBuilder.create<fir::IfOp>(loc, cmpOp, /*else*/ false);
       firOpBuilder.setInsertionPointToStart(&ifOp.getThenRegion().front());
       assert(loopIV && "loopIV was not set");
-      firOpBuilder.create<fir::StoreOp>(op->getLoc(), v, loopIV);
+      firOpBuilder.create<fir::StoreOp>(loopOp.getLoc(), v, loopIV);
       lastPrivIP = firOpBuilder.saveInsertionPoint();
     } else {
       TODO(converter.getCurrentLocation(),
@@ -266,7 +273,6 @@ void DataSharingProcessor::insertLastPrivateCompare(mlir::Operation *op) {
            "simd/worksharing-loop");
     }
   }
-  firOpBuilder.restoreInsertionPoint(localInsPt);
 }
 
 void DataSharingProcessor::collectSymbols(
diff --git a/flang/lib/Lower/OpenMP/OpenMP.cpp b/flang/lib/Lower/OpenMP/OpenMP.cpp
index 1800fcb19dcd2e..b21351382b6bdf 100644
--- a/flang/lib/Lower/OpenMP/OpenMP.cpp
+++ b/flang/lib/Lower/OpenMP/OpenMP.cpp
@@ -1626,7 +1626,9 @@ genOmpFlush(Fortran::lower::AbstractConverter &converter,
 static llvm::SmallVector<const Fortran::semantics::Symbol *>
 genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
             mlir::Location &loc,
-            llvm::ArrayRef<const Fortran::semantics::Symbol *> args) {
+            llvm::ArrayRef<const Fortran::semantics::Symbol *> args,
+            llvm::ArrayRef<const Fortran::semantics::Symbol *> wrapperSyms = {},
+            llvm::ArrayRef<mlir::BlockArgument> wrapperArgs = {}) {
   fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
   auto &region = op->getRegion(0);
 
@@ -1637,6 +1639,14 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
   llvm::SmallVector<mlir::Type> tiv(args.size(), loopVarType);
   llvm::SmallVector<mlir::Location> locs(args.size(), loc);
   firOpBuilder.createBlock(&region, {}, tiv, locs);
+
+  // Bind the entry block arguments of parent wrappers to the corresponding
+  // symbols. Do it here so that any hlfir.declare operations created as a
+  // result are inserted inside of the omp.loop_nest rather than the wrapper
+  // operations.
+  for (auto [arg, prv] : llvm::zip_equal(wrapperSyms, wrapperArgs))
+    converter.bindSymbol(*arg, prv);
+
   // The argument is not currently in memory, so make a temporary for the
   // argument, and store it there, then bind that location to the argument.
   mlir::Operation *storeOp = nullptr;
@@ -1650,58 +1660,6 @@ genLoopVars(mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
   return llvm::SmallVector<const Fortran::semantics::Symbol *>(args);
 }
 
-static llvm::SmallVector<const Fortran::semantics::Symbol *>
-genLoopAndReductionVars(
-    mlir::Operation *op, Fortran::lower::AbstractConverter &converter,
-    mlir::Location &loc,
-    llvm::ArrayRef<const Fortran::semantics::Symbol *> loopArgs,
-    llvm::ArrayRef<const Fortran::semantics::Symbol *> reductionArgs,
-    llvm::ArrayRef<mlir::Type> reductionTypes) {
-  fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
-
-  llvm::SmallVector<mlir::Type> blockArgTypes;
-  llvm::SmallVector<mlir::Location> blockArgLocs;
-  blockArgTypes.reserve(loopArgs.size() + reductionArgs.size());
-  blockArgLocs.reserve(blockArgTypes.size());
-  mlir::Block *entryBlock;
-
-  if (loopArgs.size()) {
-    std::size_t loopVarTypeSize = 0;
-    for (const Fortran::semantics::Symbol *arg : loopArgs)
-      loopVarTypeSize = std::max(loopVarTypeSize, arg->GetUltimate().size());
-    mlir::Type loopVarType = getLoopVarType(converter, loopVarTypeSize);
-    std::fill_n(std::back_inserter(blockArgTypes), loopArgs.size(),
-                loopVarType);
-    std::fill_n(std::back_inserter(blockArgLocs), loopArgs.size(), loc);
-  }
-  if (reductionArgs.size()) {
-    llvm::copy(reductionTypes, std::back_inserter(blockArgTypes));
-    std::fill_n(std::back_inserter(blockArgLocs), reductionArgs.size(), loc);
-  }
-  entryBlock = firOpBuilder.createBlock(&op->getRegion(0), {}, blockArgTypes,
-                                        blockArgLocs);
-  // The argument is not currently in memory, so make a temporary for the
-  // argument, and store it there, then bind that location to the argument.
-  if (loopArgs.size()) {
-    mlir::Operation *storeOp = nullptr;
-    for (auto [argIndex, argSymbol] : llvm::enumerate(loopArgs)) {
-      mlir::Value indexVal =
-          fir::getBase(op->getRegion(0).front().getArgument(argIndex));
-      storeOp =
-          createAndSetPrivatizedLoopVar(converter, loc, indexVal, argSymbol);
-    }
-    firOpBuilder.setInsertionPointAfter(storeOp);
-  }
-  // Bind the reduction arguments to their block arguments
-  for (auto [arg, prv] : llvm::zip_equal(
-           reductionArgs,
-           llvm::drop_begin(entryBlock->getArguments(), loopArgs.size()))) {
-    converter.bindSymbol(*arg, prv);
-  }
-
-  return llvm::SmallVector<const Fortran::semantics::Symbol *>(loopArgs);
-}
-
 static void createSimd(Fortran::lower::AbstractConverter &converter,
                        Fortran::semantics::SemanticsContext &semaCtx,
                        Fortran::lower::pft::Evaluation &eval,
@@ -1797,28 +1755,26 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
   if (ReductionProcessor::doReductionByRef(reductionVars))
     byrefOperand = firOpBuilder.getUnitAttr();
 
-  auto wsLoopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
-      loc, lowerBound, upperBound, step, linearVars, linearStepVars,
-      reductionVars,
+  auto wsloopOp = firOpBuilder.create<mlir::omp::WsloopOp>(
+      loc, linearVars, linearStepVars, reductionVars,
       reductionDeclSymbols.empty()
           ? nullptr
           : mlir::ArrayAttr::get(firOpBuilder.getContext(),
                                  reductionDeclSymbols),
       scheduleValClauseOperand, scheduleChunkClauseOperand,
-      /*schedule_modifiers=*/nullptr,
-      /*simd_modifier=*/nullptr, nowaitClauseOperand, byrefOperand,
-      orderedClauseOperand, orderClauseOperand,
-      /*inclusive=*/firOpBuilder.getUnitAttr());
+      /*schedule_modifiers=*/nullptr, /*simd_modifier=*/nullptr,
+      nowaitClauseOperand, byrefOperand, orderedClauseOperand,
+      orderClauseOperand);
 
   // Handle attribute based clauses.
   if (cp.processOrdered(orderedClauseOperand))
-    wsLoopOp.setOrderedValAttr(orderedClauseOperand);
+    wsloopOp.setOrderedValAttr(orderedClauseOperand);
 
   if (cp.processSchedule(scheduleValClauseOperand, scheduleModClauseOperand,
                          scheduleSimdClauseOperand)) {
-    wsLoopOp.setScheduleValAttr(scheduleValClauseOperand);
-    wsLoopOp.setScheduleModifierAttr(scheduleModClauseOperand);
-    wsLoopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
+    wsloopOp.setScheduleValAttr(scheduleValClauseOperand);
+    wsloopOp.setScheduleModifierAttr(scheduleModClauseOperand);
+    wsloopOp.setSimdModifierAttr(scheduleSimdClauseOperand);
   }
   // In FORTRAN `nowait` clause occur at the end of `omp do` directive.
   // i.e
@@ -1828,23 +1784,36 @@ static void createWsloop(Fortran::lower::AbstractConverter &converter,
   if (endClauseList) {
     if (ClauseProcessor(converter, semaCtx, *endClauseList)
             .processNowait(nowaitClauseOperand))
-      wsLoopOp.setNowaitAttr(nowaitClauseOperand);
+      wsloopOp.setNowaitAttr(nowaitClauseOperand);
   }
 
+  // Create omp.wsloop wrapper and populate entry block arguments with reduction
+  // variables.
+  llvm::SmallVector<mlir::Location> reductionLocs(reductionSymbols.size(), loc);
+  mlir::Block *wsloopEntryBlock = firOpBuilder.createBlock(
+      &wsloopOp.getRegion(), {}, reductionTypes, reductionLocs);
+  firOpBuilder.setInsertionPoint(
+      Fortran::lower::genOpenMPTerminator(firOpBuilder, wsloopOp, loc));
+
+  // Create nested omp.loop_nest and fill body with loop contents.
+  auto loopOp = firOpBuilder.create<mlir::omp::LoopNestOp>(
+      loc, lowerBound, upperBound, step,
+      /*inclusive=*/firOpBuilder.getUnitAttr());
+
   auto *nestedEval = getCollapsedLoopEval(
       eval, Fortran::lower::getCollapseValue(beginClauseList));
 
   auto ivCallback = [&](mlir::Operation *op) {
-    return genLoopAndReductionVars(op, converter, loc, iv, reductionSymbols,
-                                   reductionTypes);
+    return genLoopVars(op, converter, loc, iv, reductionSymbols,
+                       wsloopEntryBlock->getArguments());
   };
 
   createBodyOfOp<mlir::omp::WsloopOp>(
-      *wsLoopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
-                     .setClauses(&beginClauseList)
-                     .setDataSharingProcessor(&dsp)
-                     .setReductions(&reductionSymbols, &reductionTypes)
-                     .setGenRegionEntryCb(ivCallback));
+      *loopOp, OpWithBodyGenInfo(converter, semaCtx, loc, *nestedEval)
+                   .setClauses(&beginClauseList)
+                   .setDataSharingProcessor(&dsp)
+                   .setReductions(&reductionSymbols, &reductionTypes)
+                   .setGenRegionEntryCb(ivCallback));
 }
 
 static void createSimdWsloop(
@@ -2430,8 +2399,8 @@ static void genOMP(Fortran::lower::AbstractConverter &converter,
 mlir::Operation *Fortran::lower::genOpenMPTerminator(fir::FirOpBuilder &builder,
                                                      mlir::Operation *op,
                                                      mlir::Location loc) {
-  if (mlir::isa<mlir::omp::WsloopOp, mlir::omp::DeclareReductionOp,
-                mlir::omp::AtomicUpdateOp, mlir::omp::LoopNestOp>(op))
+  if (mlir::isa<mlir::omp::AtomicUpdateOp, mlir::omp::DeclareReductionOp,
+                mlir::omp::LoopNestOp>(op))
     return builder.create<mlir::omp::YieldOp>(loc);
   return builder.create<mlir::omp::TerminatorOp>(loc);
 }
diff --git a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
index fa7979e8875afc..c7c609bbb35623 100644
--- a/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
+++ b/flang/test/Fir/convert-to-llvm-openmp-and-fir.fir
@@ -7,15 +7,17 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
   omp.parallel  {
     %1 = fir.alloca i32 {adapt.valuebyref, pinned}
     %2 = fir.load %arg0 : !fir.ref<i32>
-    omp.wsloop nowait
-    for (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32)  {
-      fir.store %arg2 to %1 : !fir.ref<i32>
-      %3 = fir.load %1 : !fir.ref<i32>
-      %4 = fir.convert %3 : (i32) -> i64
-      %5 = arith.subi %4, %c1_i64 : i64
-      %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
-      fir.store %3 to %6 : !fir.ref<i32>
-      omp.yield
+    omp.wsloop nowait {
+      omp.loop_nest (%arg2) : i32 = (%c1_i32) to (%2) inclusive step (%c1_i32)  {
+        fir.store %arg2 to %1 : !fir.ref<i32>
+        %3 = fir.load %1 : !fir.ref<i32>
+        %4 = fir.convert %3 : (i32) -> i64
+        %5 = arith.subi %4, %c1_i64 : i64
+        %6 = fir.coordinate_of %arg1, %5 : (!fir.ref<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+        fir.store %3 to %6 : !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -31,7 +33,7 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
 // CHECK:      %[[I_VAR:.*]] = llvm.alloca %[[ONE_3]] x i32 {pinned} : (i64) -> !llvm.ptr
 // CHECK:      %[[N:.*]] = llvm.load %[[N_REF]] : !llvm.ptr -> i32
 // CHECK: omp.wsloop nowait
-// CHECK-SAME: for (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
+// CHECK-NEXT: omp.loop_nest (%[[I:.*]]) : i32 = (%[[ONE_2]]) to (%[[N]]) inclusive step (%[[ONE_2]]) {
 // CHECK:   llvm.store %[[I]], %[[I_VAR]] : i32, !llvm.ptr
 // CHECK:   %[[I1:.*]] = llvm.load %[[I_VAR]] : !llvm.ptr -> i32
 // CHECK:   %[[I1_EXT:.*]] = llvm.sext %[[I1]] : i32 to i64
@@ -42,6 +44,8 @@ func.func @_QPsb1(%arg0: !fir.ref<i32> {fir.bindc_name = "n"}, %arg1: !fir.ref<!
 // CHECK: }
 // CHECK: omp.terminator
 // CHECK: }
+// CHECK: omp.terminator
+// CHECK: }
 // CHECK: llvm.return
 // CHECK: }
 
@@ -79,13 +83,16 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
   omp.parallel   {
     %c1 = arith.constant 1 : i32
     %c50 = arith.constant 50 : i32
-    omp.wsloop   for  (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
-      %1 = fir.convert %indx : (i32) -> i64
-      %c1_i64 = arith.constant 1 : i64
-      %2 = arith.subi %1, %c1_i64 : i64
-      %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
-      fir.store %indx to %3 : !fir.ref<i32>
-      omp.yield
+    omp.wsloop {
+      omp.loop_nest (%indx) : i32 = (%c1) to (%c50) inclusive step (%c1) {
+        %1 = fir.convert %indx : (i32) -> i64
+        %c1_i64 = arith.constant 1 : i64
+        %2 = arith.subi %1, %c1_i64 : i64
+        %3 = fir.coordinate_of %arr, %2 : (!fir.box<!fir.array<?xi32>>, i64) -> !fir.ref<i32>
+        fir.store %indx to %3 : !fir.ref<i32>
+        omp.yield
+      }
+      omp.terminator
     }
     omp.terminator
   }
@@ -98,9 +105,11 @@ func.func @_QPsb(%arr: !fir.box<!fir.array<?xi32>> {fir.bindc_name = "arr"}) {
 // CHECK:    omp.parallel   {
 // CHECK:      %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
 // CHECK:      %[[C50:.*]] = llvm.mlir.constant(50 : i32) : i32
-// CHECK:      omp.wsloop   for  (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
-// CHECK:        llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
-// CHECK:        omp.yield
+// CHECK:      omp.wsloop {
+// CHECK-NEXT:   omp.loop_nest (%[[INDX:.*]]) : i32 = (%[[C1]]) to (%[[C50]]) inclusive step (%[[C1]]) {
+// CHECK:          llvm.store %[[INDX]], %{{.*}} : i32, !llvm.ptr
+// CHECK:          omp.yield
+// CHECK:        omp.terminator
 // CHECK:      omp.terminator
 // CHECK:    llvm.return
 
@@ -708,18 +717,20 @@ func.func @_QPsb() {
 // CHECK-SAME: %[[ARRAY_REF:.*]]: !llvm.ptr
 // CHECK:    %[[RED_ACCUMULATOR:.*]] = llvm.alloca %2 x i32 {bindc_name = "x"} : (i64) -> !llvm.ptr
 // CHECK:    omp.parallel   {
-// CHECK:      omp.wsloop   reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) for
-// CHECK:        %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
-// CHECK:        %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
-// CHECK:        %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
-// CHECK:        %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK:        %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
-// CHECK:        %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
-// CHECK:        %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
-// CHECK:        %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
-// CHECK:        %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
-// CHECK:        llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
-// CHECK:        omp.yield
+// CHECK:      omp.wsloop reduction(@[[EQV_REDUCTION]] %[[RED_ACCUMULATOR]] -> %[[PRV:.+]] : !llvm.ptr) {
+// CHECK-NEXT:   omp.loop_nest
+// CHECK:          %[[ARRAY_ELEM_REF:.*]] = llvm.getelementptr %[[ARRAY_REF]][0, %{{.*}}] : (!llvm.ptr, i64) -> !llvm.ptr
+// CHECK:          %[[ARRAY_ELEM:.*]] = llvm.load %[[ARRAY_ELEM_REF]] : !llvm.ptr -> i32
+// CHECK:          %[[LPRV:.+]] = llvm.load %[[PRV]] : !llvm.ptr -> i32
+// CHECK:          %[[ZERO_1:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK:          %[[ARGVAL_1:.*]] = llvm.icmp "ne" %[[LPRV]], %[[ZERO_1]] : i32
+// CHECK:          %[[ZERO_2:.*]] = llvm.mlir.constant(0 : i64) : i32
+// CHECK:          %[[ARGVAL_2:.*]] = llvm.icmp "ne" %[[ARRAY_ELEM]], %[[ZERO_2]] : i32
+// CHECK:          %[[RES:.*]] = llvm.icmp "eq" %[[ARGVAL_2]], %[[ARGVAL_1]] : i1
+// CHECK:          %[[RES_EXT:.*]] = llvm.zext %[[RES]] : i1 to i32
+// CHECK:          llvm.store %[[RES_EXT]], %[[PRV]] : i32, !llvm.ptr
+// CHECK:          omp.yield
+// CHECK:        omp.terminator
 // CHECK:      omp.terminator
 // CHECK:    llvm.return
 
@@ -747,21 +758,24 @@ func.func @_QPsimple_reduction(%arg0: !fir.ref<!fir.array<100x!fir.logical<4>>>
     %c1_i32 = arith.constant 1 : i32
     %c100_i32 = arith.constant 100 : i32
     %c1_i32_0 = arith.constant 1 : i32
-    omp.wsloop   reduction(@eqv_reduction %1 -> %prv : !fir.ref<!fir.logical<4>>) for  (%arg1) : i32 = (%c1_i32) to (%c100_i32) inclusive step (%c1_i32_0) {
-      fir.store %arg1 to %3 : !fir.ref<i32>
-      %4 = fir.load %3 : !fir.ref<i32>
-      %5 = fir.convert %4 : (i32) -> i64
-      %c1_i64 = arith.constant 1 : i64
-      %6 = arith.subi %5, %c1_i64 : i64
-      %7 = fir.coordinate_of %arg0, %6 : (!fir.ref<!fir.array<100x!fir.logical<4>>>, i64) -> !fir.ref<!fir.logical<4>>
-      %8 = fir.load %7 : !fir.ref<!fir.logical<4>>
-      %lprv = fir.load %prv : !fir.ref<!fir.logical<4>>
-      %lprv1 = fir.convert %lprv : (!fir.logical<...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/88403

