[llvm-branch-commits] [flang] [flang] Introduce ws loop nest generation for HLFIR lowering (PR #101445)

Ivan R. Ivanov via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Wed Jul 31 19:59:04 PDT 2024


https://github.com/ivanradanov updated https://github.com/llvm/llvm-project/pull/101445

>From 4da93bb2a99ac1d59d4924c518503c94ec81c659 Mon Sep 17 00:00:00 2001
From: Ivan Radanov Ivanov <ivanov.i.aa at m.titech.ac.jp>
Date: Wed, 31 Jul 2024 14:12:34 +0900
Subject: [PATCH] [flang] Introduce ws loop nest generation for HLFIR lowering

---
 .../flang/Optimizer/Builder/HLFIRTools.h      | 12 +++--
 flang/lib/Lower/ConvertCall.cpp               |  2 +-
 flang/lib/Lower/OpenMP/ReductionProcessor.cpp |  4 +-
 flang/lib/Optimizer/Builder/HLFIRTools.cpp    | 52 ++++++++++++++-----
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  3 +-
 .../LowerHLFIROrderedAssignments.cpp          | 30 +++++------
 .../Transforms/OptimizedBufferization.cpp     |  6 +--
 7 files changed, 69 insertions(+), 40 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/HLFIRTools.h b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
index 6b41025eea078..14e42c6f358e4 100644
--- a/flang/include/flang/Optimizer/Builder/HLFIRTools.h
+++ b/flang/include/flang/Optimizer/Builder/HLFIRTools.h
@@ -357,8 +357,8 @@ hlfir::ElementalOp genElementalOp(
 
 /// Structure to describe a loop nest.
 struct LoopNest {
-  fir::DoLoopOp outerLoop;
-  fir::DoLoopOp innerLoop;
+  mlir::Operation *outerOp = nullptr; ///< Outermost operation of the nest.
+  mlir::Block *body = nullptr;        ///< Block where the loop body goes.
   llvm::SmallVector<mlir::Value> oneBasedIndices;
 };
 
@@ -366,11 +366,13 @@ struct LoopNest {
 /// \p isUnordered specifies whether the loops in the loop nest
 /// are unordered.
 LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                     mlir::ValueRange extents, bool isUnordered = false);
+                     mlir::ValueRange extents, bool isUnordered = false,
+                     bool emitWsLoop = false);
 inline LoopNest genLoopNest(mlir::Location loc, fir::FirOpBuilder &builder,
-                            mlir::Value shape, bool isUnordered = false) {
+                            mlir::Value shape, bool isUnordered = false,
+                            bool emitWsLoop = false) {
   return genLoopNest(loc, builder, getIndexExtents(loc, builder, shape),
-                     isUnordered);
+                     isUnordered, emitWsLoop);
 }
 
 /// Inline the body of an hlfir.elemental at the current insertion point
diff --git a/flang/lib/Lower/ConvertCall.cpp b/flang/lib/Lower/ConvertCall.cpp
index fd873f55dd844..0689d6e033dd9 100644
--- a/flang/lib/Lower/ConvertCall.cpp
+++ b/flang/lib/Lower/ConvertCall.cpp
@@ -2128,7 +2128,7 @@ class ElementalCallBuilder {
           hlfir::genLoopNest(loc, builder, shape, !mustBeOrdered);
       mlir::ValueRange oneBasedIndices = loopNest.oneBasedIndices;
       auto insPt = builder.saveInsertionPoint();
-      builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+      builder.setInsertionPointToStart(loopNest.body);
       callContext.stmtCtx.pushScope();
       for (auto &preparedActual : loweredActuals)
         if (preparedActual)
diff --git a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
index c3c1f363033c2..72a90dd0d6f29 100644
--- a/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ReductionProcessor.cpp
@@ -375,7 +375,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
   // know this won't miss any opportuinties for clever elemental inlining
   hlfir::LoopNest nest = hlfir::genLoopNest(
       loc, builder, shapeShift.getExtents(), /*isUnordered=*/true);
-  builder.setInsertionPointToStart(nest.innerLoop.getBody());
+  builder.setInsertionPointToStart(nest.body);
   mlir::Type refTy = fir::ReferenceType::get(seqTy.getEleTy());
   auto lhsEleAddr = builder.create<fir::ArrayCoorOp>(
       loc, refTy, lhs, shapeShift, /*slice=*/mlir::Value{},
@@ -389,7 +389,7 @@ static void genBoxCombiner(fir::FirOpBuilder &builder, mlir::Location loc,
       builder, loc, redId, refTy, lhsEle, rhsEle);
   builder.create<fir::StoreOp>(loc, scalarReduction, lhsEleAddr);
 
-  builder.setInsertionPointAfter(nest.outerLoop);
+  builder.setInsertionPointAfter(nest.outerOp);
   builder.create<mlir::omp::YieldOp>(loc, lhsAddr);
 }
 
diff --git a/flang/lib/Optimizer/Builder/HLFIRTools.cpp b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
index 8d0ae2f195178..cd07cb741eb4b 100644
--- a/flang/lib/Optimizer/Builder/HLFIRTools.cpp
+++ b/flang/lib/Optimizer/Builder/HLFIRTools.cpp
@@ -20,6 +20,7 @@
 #include "mlir/IR/IRMapping.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include <mlir/Dialect/OpenMP/OpenMPDialect.h>
 #include <optional>
 
 // Return explicit extents. If the base is a fir.box, this won't read it to
@@ -855,26 +856,51 @@ mlir::Value hlfir::inlineElementalOp(
 
 hlfir::LoopNest hlfir::genLoopNest(mlir::Location loc,
                                    fir::FirOpBuilder &builder,
-                                   mlir::ValueRange extents, bool isUnordered) {
+                                   mlir::ValueRange extents, bool isUnordered,
+                                   bool emitWsLoop) {
   hlfir::LoopNest loopNest;
   assert(!extents.empty() && "must have at least one extent");
-  auto insPt = builder.saveInsertionPoint();
+  mlir::OpBuilder::InsertionGuard guard(builder); // Restores insertion point.
   loopNest.oneBasedIndices.assign(extents.size(), mlir::Value{});
   // Build loop nest from column to row.
   auto one = builder.create<mlir::arith::ConstantIndexOp>(loc, 1);
   mlir::Type indexType = builder.getIndexType();
-  unsigned dim = extents.size() - 1;
-  for (auto extent : llvm::reverse(extents)) {
-    auto ub = builder.createConvert(loc, indexType, extent);
-    loopNest.innerLoop =
-        builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
-    // Reverse the indices so they are in column-major order.
-    loopNest.oneBasedIndices[dim--] = loopNest.innerLoop.getInductionVar();
-    if (!loopNest.outerLoop)
-      loopNest.outerLoop = loopNest.innerLoop;
+  if (emitWsLoop) {
+    auto wsloop = builder.create<mlir::omp::WsloopOp>(
+        loc, mlir::ArrayRef<mlir::NamedAttribute>());
+    loopNest.outerOp = wsloop; // The wsloop encloses the loop nest.
+    builder.createBlock(&wsloop.getRegion());
+    mlir::omp::LoopNestOperands lnops;
+    lnops.loopInclusive = builder.getUnitAttr();
+    for (auto extent : llvm::reverse(extents)) {
+      lnops.loopLowerBounds.push_back(one);
+      lnops.loopUpperBounds.push_back(extent);
+      lnops.loopSteps.push_back(one);
+    }
+    auto lnOp = builder.create<mlir::omp::LoopNestOp>(loc, lnops);
+    builder.create<mlir::omp::TerminatorOp>(loc); // Ends the wsloop block.
+    mlir::Block *block = builder.createBlock(&lnOp.getRegion());
+    for (auto extent : llvm::reverse(extents))
+      block->addArgument(extent.getType(), extent.getLoc());
+    loopNest.body = block;
+    builder.create<mlir::omp::YieldOp>(loc); // Ends the loop nest body.
+    for (unsigned dim = 0; dim < extents.size(); dim++)
+      loopNest.oneBasedIndices[extents.size() - dim - 1] =
+          lnOp.getRegion().front().getArgument(dim);
+  } else {
+    unsigned dim = extents.size() - 1;
+    for (auto extent : llvm::reverse(extents)) {
+      auto ub = builder.createConvert(loc, indexType, extent);
+      auto doLoop =
+          builder.create<fir::DoLoopOp>(loc, one, ub, one, isUnordered);
+      loopNest.body = doLoop.getBody(); // Ends up as the innermost body.
+      builder.setInsertionPointToStart(loopNest.body);
+      // Reverse the indices so they are in column-major order.
+      loopNest.oneBasedIndices[dim--] = doLoop.getInductionVar();
+      if (!loopNest.outerOp)
+        loopNest.outerOp = doLoop; // First loop created is the outermost.
+    }
   }
-  builder.restoreInsertionPoint(insPt);
   return loopNest;
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index a70a6b388c4b1..b608677c52631 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -31,6 +31,7 @@
 #include "mlir/Pass/Pass.h"
 #include "mlir/Pass/PassManager.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
 #include "llvm/ADT/TypeSwitch.h"
 
 namespace hlfir {
@@ -793,7 +794,7 @@ struct ElementalOpConversion
     hlfir::LoopNest loopNest =
         hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
     auto insPt = builder.saveInsertionPoint();
-    builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+    builder.setInsertionPointToStart(loopNest.body);
     auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                           loopNest.oneBasedIndices);
     hlfir::Entity elementValue(yield.getElementValue());
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
index 85dd517cb5791..645abf65d10a3 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/LowerHLFIROrderedAssignments.cpp
@@ -464,7 +464,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
       // if the LHS is not).
       mlir::Value shape = hlfir::genShape(loc, builder, lhsEntity);
       elementalLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      builder.setInsertionPointToStart(elementalLoopNest->innerLoop.getBody());
+      builder.setInsertionPointToStart(elementalLoopNest->body);
       lhsEntity = hlfir::getElementAt(loc, builder, lhsEntity,
                                       elementalLoopNest->oneBasedIndices);
       rhsEntity = hlfir::getElementAt(loc, builder, rhsEntity,
@@ -484,7 +484,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
     for (auto &cleanupConversion : argConversionCleanups)
       cleanupConversion();
     if (elementalLoopNest)
-      builder.setInsertionPointAfter(elementalLoopNest->outerLoop);
+      builder.setInsertionPointAfter(elementalLoopNest->outerOp);
   } else {
     // TODO: preserve allocatable assignment aspects for forall once
     // they are conveyed in hlfir.region_assign.
@@ -493,7 +493,7 @@ void OrderedAssignmentRewriter::pre(hlfir::RegionAssignOp regionAssignOp) {
   generateCleanupIfAny(loweredLhs.elementalCleanup);
   if (loweredLhs.vectorSubscriptLoopNest)
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   generateCleanupIfAny(oldRhsYield);
   generateCleanupIfAny(loweredLhs.nonElementalCleanup);
 }
@@ -518,8 +518,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
       hlfir::Entity savedMask{maybeSaved->first};
       mlir::Value shape = hlfir::genShape(loc, builder, savedMask);
       whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-      constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-      builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+      constructStack.push_back(whereLoopNest->outerOp);
+      builder.setInsertionPointToStart(whereLoopNest->body);
       mlir::Value cdt = hlfir::getElementAt(loc, builder, savedMask,
                                             whereLoopNest->oneBasedIndices);
       generateMaskIfOp(cdt);
@@ -527,7 +527,7 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
         // If this is the same run as the one that saved the value, the clean-up
         // was left-over to be done now.
         auto insertionPoint = builder.saveInsertionPoint();
-        builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+        builder.setInsertionPointAfter(whereLoopNest->outerOp);
         generateCleanupIfAny(maybeSaved->second);
         builder.restoreInsertionPoint(insertionPoint);
       }
@@ -539,8 +539,8 @@ void OrderedAssignmentRewriter::pre(hlfir::WhereOp whereOp) {
     mask.generateNoneElementalPart(builder, mapper);
     mlir::Value shape = mask.generateShape(builder, mapper);
     whereLoopNest = hlfir::genLoopNest(loc, builder, shape);
-    constructStack.push_back(whereLoopNest->outerLoop.getOperation());
-    builder.setInsertionPointToStart(whereLoopNest->innerLoop.getBody());
+    constructStack.push_back(whereLoopNest->outerOp);
+    builder.setInsertionPointToStart(whereLoopNest->body);
     mlir::Value cdt = generateMaskedEntity(mask);
     generateMaskIfOp(cdt);
     return;
@@ -754,7 +754,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
       loweredLhs.vectorSubscriptLoopNest = hlfir::genLoopNest(
           loc, builder, loweredLhs.vectorSubscriptShape.value());
       builder.setInsertionPointToStart(
-          loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+          loweredLhs.vectorSubscriptLoopNest->body);
     }
     loweredLhs.lhs = temp->second.fetch(loc, builder);
     return loweredLhs;
@@ -772,7 +772,7 @@ OrderedAssignmentRewriter::generateYieldedLHS(
         hlfir::genLoopNest(loc, builder, *loweredLhs.vectorSubscriptShape,
                            !elementalAddrLhs.isOrdered());
     builder.setInsertionPointToStart(
-        loweredLhs.vectorSubscriptLoopNest->innerLoop.getBody());
+        loweredLhs.vectorSubscriptLoopNest->body);
     mapper.map(elementalAddrLhs.getIndices(),
                loweredLhs.vectorSubscriptLoopNest->oneBasedIndices);
     for (auto &op : elementalAddrLhs.getBody().front().without_terminator())
@@ -798,11 +798,11 @@ OrderedAssignmentRewriter::generateMaskedEntity(MaskedArrayExpr &maskedExpr) {
   if (!maskedExpr.noneElementalPartWasGenerated) {
     // Generate none elemental part before the where loops (but inside the
     // current forall loops if any).
-    builder.setInsertionPoint(whereLoopNest->outerLoop);
+    builder.setInsertionPoint(whereLoopNest->outerOp);
     maskedExpr.generateNoneElementalPart(builder, mapper);
   }
   // Generate the none elemental part cleanup after the where loops.
-  builder.setInsertionPointAfter(whereLoopNest->outerLoop);
+  builder.setInsertionPointAfter(whereLoopNest->outerOp);
   maskedExpr.generateNoneElementalCleanupIfAny(builder, mapper);
   // Generate the value of the current element for the masked expression
   // at the current insertion point (inside the where loops, and any fir.if
@@ -1242,7 +1242,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   LhsValueAndCleanUp loweredLhs = generateYieldedLHS(loc, region);
   fir::factory::TemporaryStorage *temp = nullptr;
   if (loweredLhs.vectorSubscriptLoopNest)
-    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    constructStack.push_back(loweredLhs.vectorSubscriptLoopNest->outerOp);
   if (loweredLhs.vectorSubscriptLoopNest && !rhsIsArray(regionAssignOp)) {
     // Vector subscripted entity for which the shape must also be saved on top
     // of the element addresses (e.g. the shape may change in each forall
@@ -1265,7 +1265,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
     // subscripted LHS.
     auto &vectorTmp = temp->cast<fir::factory::AnyVectorSubscriptStack>();
     auto insertionPoint = builder.saveInsertionPoint();
-    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerLoop);
+    builder.setInsertionPoint(loweredLhs.vectorSubscriptLoopNest->outerOp);
     vectorTmp.pushShape(loc, builder, shape);
     builder.restoreInsertionPoint(insertionPoint);
   } else {
@@ -1291,7 +1291,7 @@ void OrderedAssignmentRewriter::saveLeftHandSide(
   if (loweredLhs.vectorSubscriptLoopNest) {
     constructStack.pop_back();
     builder.setInsertionPointAfter(
-        loweredLhs.vectorSubscriptLoopNest->outerLoop);
+        loweredLhs.vectorSubscriptLoopNest->outerOp);
   }
 }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
index c5b809514c54c..c4aed6b79df92 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/OptimizedBufferization.cpp
@@ -483,7 +483,7 @@ llvm::LogicalResult ElementalAssignBufferization::matchAndRewrite(
   // hlfir.elemental region inside the inner loop
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, !elemental.isOrdered());
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto yield = hlfir::inlineElementalOp(loc, builder, elemental,
                                         loopNest.oneBasedIndices);
   hlfir::Entity elementValue{yield.getElementValue()};
@@ -554,7 +554,7 @@ llvm::LogicalResult BroadcastAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto arrayElement =
       hlfir::getElementAt(loc, builder, lhs, loopNest.oneBasedIndices);
   builder.create<hlfir::AssignOp>(loc, rhs, arrayElement);
@@ -649,7 +649,7 @@ llvm::LogicalResult VariableAssignBufferization::matchAndRewrite(
       hlfir::getIndexExtents(loc, builder, shape);
   hlfir::LoopNest loopNest =
       hlfir::genLoopNest(loc, builder, extents, /*isUnordered=*/true);
-  builder.setInsertionPointToStart(loopNest.innerLoop.getBody());
+  builder.setInsertionPointToStart(loopNest.body);
   auto rhsArrayElement =
       hlfir::getElementAt(loc, builder, rhs, loopNest.oneBasedIndices);
   rhsArrayElement = hlfir::loadTrivialScalar(loc, builder, rhsArrayElement);



More information about the llvm-branch-commits mailing list