[flang-commits] [flang] [mlir][Interfaces] `LoopLikeOpInterface`: Support ops with multiple regions (PR #66754)

Alexandros Lamprineas via flang-commits flang-commits at lists.llvm.org
Tue Sep 19 06:42:29 PDT 2023


https://github.com/labrinea updated https://github.com/llvm/llvm-project/pull/66754

>From 5f22e1a810c5fe67a071a737fb75c088da10c250 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Tue, 19 Sep 2023 15:25:13 +0200
Subject: [PATCH] [mlir][Interfaces] `LoopLikeOpInterface`: Support ops with
 multiple regions

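This commit implements `LoopLikeOpInterface` on `scf.while`, reporting both its "before" and "after" regions as loop regions. This enables LICM (and potentially other transforms) on `scf.while`; a LICM test for `scf.while` is added.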

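`LoopLikeOpInterface::getLoopBody()`, which returned a single region, is replaced by `LoopLikeOpInterface::getLoopRegions()`, which returns a list of regions (`SmallVector<Region *>`). All implementations of the interface are updated accordingly. Call sites that relied on `getLoopBody()` for a specific single-region op now use that op's own accessors (e.g. `getRegion()` or `getBody()`).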

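Also fix a bug in the default implementation of `LoopLikeOpInterface::isDefinedOutsideOfLoop()`, which returned "false" for some values that are defined outside of the loop. The old check only returned "true" when the value's parent region was a proper ancestor of the loop body. A value defined outside of the loop in a region that does not enclose the loop (e.g., inside a region of another op, so that the value does not dominate the loop) was therefore misclassified. The new implementation returns "true" whenever the op owning the value's parent region is neither the loop op nor nested inside it. This interface is currently only used for LICM, where there is no way to trigger this bug, so no test is added for it.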

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 flang/lib/Optimizer/Dialect/FIROps.cpp        |  8 ++++--
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  1 +
 .../mlir/Interfaces/LoopLikeInterface.td      |  8 +++---
 .../Transforms/LoopInvariantCodeMotionUtils.h |  5 ++--
 mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp |  2 +-
 .../Conversion/VectorToGPU/VectorToGPU.cpp    |  7 +++--
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 20 +++++++-------
 .../AffineLoopInvariantCodeMotion.cpp         |  4 +--
 .../Async/Transforms/AsyncParallelFor.cpp     |  4 +--
 .../Dialect/Linalg/Transforms/Hoisting.cpp    |  3 +--
 mlir/lib/Dialect/Linalg/Transforms/Loops.cpp  |  6 ++---
 .../Linalg/Transforms/SubsetHoisting.cpp      |  2 +-
 .../Dialect/MemRef/Transforms/MultiBuffer.cpp |  8 +++---
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 20 ++++++++------
 .../SCF/IR/ValueBoundsOpInterfaceImpl.cpp     |  9 +++----
 .../BufferizableOpInterfaceImpl.cpp           | 16 +++++------
 .../Transforms/StructuralTypeConversions.cpp  |  6 ++---
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          |  4 +--
 .../Transforms/SparseGPUCodegen.cpp           |  4 +--
 mlir/lib/Dialect/Tosa/IR/TosaOps.cpp          |  2 +-
 .../Utils/LoopInvariantCodeMotionUtils.cpp    |  4 +--
 .../loop-invariant-code-motion.mlir           | 27 +++++++++++++++++++
 mlir/test/lib/Dialect/Test/TestOps.td         |  2 +-
 23 files changed, 101 insertions(+), 71 deletions(-)

diff --git a/flang/lib/Optimizer/Dialect/FIROps.cpp b/flang/lib/Optimizer/Dialect/FIROps.cpp
index 80567b19f9fe5ed..962b87acd5a8050 100644
--- a/flang/lib/Optimizer/Dialect/FIROps.cpp
+++ b/flang/lib/Optimizer/Dialect/FIROps.cpp
@@ -1947,7 +1947,9 @@ void fir::IterWhileOp::print(mlir::OpAsmPrinter &p) {
                 /*printBlockTerminators=*/true);
 }
 
-mlir::Region &fir::IterWhileOp::getLoopBody() { return getRegion(); }
+llvm::SmallVector<mlir::Region *> fir::IterWhileOp::getLoopRegions() {
+  return {&getRegion()};
+}
 
 mlir::BlockArgument fir::IterWhileOp::iterArgToBlockArg(mlir::Value iterArg) {
   for (auto i : llvm::enumerate(getInitArgs()))
@@ -2234,7 +2236,9 @@ void fir::DoLoopOp::print(mlir::OpAsmPrinter &p) {
                 printBlockTerminators);
 }
 
-mlir::Region &fir::DoLoopOp::getLoopBody() { return getRegion(); }
+llvm::SmallVector<mlir::Region *> fir::DoLoopOp::getLoopRegions() {
+  return {&getRegion()};
+}
 
 /// Translate a value passed as an iter_arg to the corresponding block
 /// argument in the body of the loop.
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 8a9ce949a750d43..cbc9449f3c36bb2 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -958,6 +958,7 @@ def ReduceReturnOp :
 def WhileOp : SCF_Op<"while",
     [DeclareOpInterfaceMethods<RegionBranchOpInterface,
         ["getEntrySuccessorOperands"]>,
+     DeclareOpInterfaceMethods<LoopLikeOpInterface>,
      RecursiveMemoryEffects, SingleBlock]> {
   let summary = "a generic 'while' loop";
   let description = [{
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index 9ccc97251e7e669..44d32dd609fc9d5 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -36,15 +36,15 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
       /*args=*/(ins "::mlir::Value ":$value),
       /*methodBody=*/"",
       /*defaultImplementation=*/[{
-        return value.getParentRegion()->isProperAncestor(&$_op.getLoopBody());
+        return !$_op->isAncestor(value.getParentRegion()->getParentOp());
       }]
     >,
     InterfaceMethod<[{
-        Returns the region that makes up the body of the loop and should be
+        Returns the regions that make up the body of the loop and should be
         inspected for loop-invariant operations.
       }],
-      /*retTy=*/"::mlir::Region &",
-      /*methodName=*/"getLoopBody"
+      /*retTy=*/"::llvm::SmallVector<::mlir::Region *>",
+      /*methodName=*/"getLoopRegions"
     >,
     InterfaceMethod<[{
         Moves the given loop-invariant operation out of the loop.
diff --git a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
index e54675967d82552..c7b816eb28faf5f 100644
--- a/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
+++ b/mlir/include/mlir/Transforms/LoopInvariantCodeMotionUtils.h
@@ -11,12 +11,13 @@
 
 #include "mlir/Support/LLVM.h"
 
+#include "llvm/ADT/SmallVector.h"
+
 namespace mlir {
 
 class LoopLikeOpInterface;
 class Operation;
 class Region;
-class RegionRange;
 class Value;
 
 /// Given a list of regions, perform loop-invariant code motion. An operation is
@@ -61,7 +62,7 @@ class Value;
 ///
 /// Returns the number of operations moved.
 size_t moveLoopInvariantCode(
-    RegionRange regions,
+    ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
     function_ref<void(Operation *, Region *)> moveOutOfRegion);
diff --git a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
index 81a6378fc7e49e3..f6e3053b8ae6a96 100644
--- a/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
+++ b/mlir/lib/Conversion/SCFToSPIRV/SCFToSPIRV.cpp
@@ -163,7 +163,7 @@ struct ForOpConversion final : SCFToSPIRVPattern<scf::ForOp> {
     signatureConverter.remapInput(0, newIndVar);
     for (unsigned i = 1, e = body->getNumArguments(); i < e; i++)
       signatureConverter.remapInput(i, header->getArgument(i));
-    body = rewriter.applySignatureConversion(&forOp.getLoopBody(),
+    body = rewriter.applySignatureConversion(&forOp.getRegion(),
                                              signatureConverter);
 
     // Move the blocks from the forOp into the loopOp. This is the body of the
diff --git a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
index c8871c945cbe759..f0412648608a6e4 100644
--- a/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
+++ b/mlir/lib/Conversion/VectorToGPU/VectorToGPU.cpp
@@ -1103,7 +1103,7 @@ convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
 }
 
 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
-// updated and needs to be updated separatly for the loop to be correct.
+// updated and needs to be updated separately for the loop to be correct.
 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
                                                scf::ForOp loop,
                                                ValueRange newInitArgs) {
@@ -1119,9 +1119,8 @@ static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
       operands);
   newLoop.getBody()->erase();
 
-  newLoop.getLoopBody().getBlocks().splice(
-      newLoop.getLoopBody().getBlocks().begin(),
-      loop.getLoopBody().getBlocks());
+  newLoop.getRegion().getBlocks().splice(
+      newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
   for (Value operand : newInitArgs)
     newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 4455cc88e65e55d..5899c198b703b5e 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2380,8 +2380,7 @@ void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
 /// induction variable. AffineForOp only has one region, so zero is the only
 /// valid value for `index`.
 OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
-  assert((point.isParent() || point == getLoopBody()) &&
-         "invalid region point");
+  assert((point.isParent() || point == getRegion()) && "invalid region point");
 
   // The initial operands map to the loop arguments after the induction
   // variable or are forwarded to the results when the trip count is zero.
@@ -2395,8 +2394,7 @@ OperandRange AffineForOp::getEntrySuccessorOperands(RegionBranchPoint point) {
 /// not a constant.
 void AffineForOp::getSuccessorRegions(
     RegionBranchPoint point, SmallVectorImpl<RegionSuccessor> &regions) {
-  assert((point.isParent() || point == getLoopBody()) &&
-         "expected loop region");
+  assert((point.isParent() || point == getRegion()) && "expected loop region");
   // The loop may typically branch back to its body or to the parent operation.
   // If the predecessor is the parent op and the trip count is known to be at
   // least one, branch into the body using the iterator arguments. And in cases
@@ -2404,7 +2402,7 @@ void AffineForOp::getSuccessorRegions(
   std::optional<uint64_t> tripCount = getTrivialConstantTripCount(*this);
   if (point.isParent() && tripCount.has_value()) {
     if (tripCount.value() > 0) {
-      regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+      regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
       return;
     }
     if (tripCount.value() == 0) {
@@ -2422,7 +2420,7 @@ void AffineForOp::getSuccessorRegions(
 
   // In all other cases, the loop may branch back to itself or the parent
   // operation.
-  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
   regions.push_back(RegionSuccessor(getResults()));
 }
 
@@ -2561,7 +2559,7 @@ bool AffineForOp::matchingBoundOperandList() {
   return true;
 }
 
-Region &AffineForOp::getLoopBody() { return getRegion(); }
+SmallVector<Region *> AffineForOp::getLoopRegions() { return {&getRegion()}; }
 
 std::optional<Value> AffineForOp::getSingleInductionVar() {
   return getInductionVar();
@@ -2758,9 +2756,9 @@ AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
       b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
                             loop.getStep(), operands);
   // Take the body of the original parent loop.
-  newLoop.getLoopBody().takeBody(loop.getLoopBody());
+  newLoop.getRegion().takeBody(loop.getRegion());
   for (Value val : newIterArgs)
-    newLoop.getLoopBody().addArgument(val.getType(), val.getLoc());
+    newLoop.getRegion().addArgument(val.getType(), val.getLoc());
 
   // Update yield operation with new values to be added.
   if (!newYieldedValues.empty()) {
@@ -3848,7 +3846,9 @@ void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
     ensureTerminator(*bodyRegion, builder, result.location);
 }
 
-Region &AffineParallelOp::getLoopBody() { return getRegion(); }
+SmallVector<Region *> AffineParallelOp::getLoopRegions() {
+  return {&getRegion()};
+}
 
 unsigned AffineParallelOp::getNumDims() { return getSteps().size(); }
 
diff --git a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
index c9b7f25c545cd1e..6b3776533bed4ae 100644
--- a/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
+++ b/mlir/lib/Dialect/Affine/Transforms/AffineLoopInvariantCodeMotion.cpp
@@ -85,11 +85,11 @@ static bool isOpLoopInvariant(Operation &op, Value indVar, ValueRange iterArgs,
                                       opsToHoist))
       return false;
   } else if (auto forOp = dyn_cast<AffineForOp>(op)) {
-    if (!areAllOpsInTheBlockListInvariant(forOp.getLoopBody(), indVar, iterArgs,
+    if (!areAllOpsInTheBlockListInvariant(forOp.getRegion(), indVar, iterArgs,
                                           opsWithUsers, opsToHoist))
       return false;
   } else if (auto parOp = dyn_cast<AffineParallelOp>(op)) {
-    if (!areAllOpsInTheBlockListInvariant(parOp.getLoopBody(), indVar, iterArgs,
+    if (!areAllOpsInTheBlockListInvariant(parOp.getRegion(), indVar, iterArgs,
                                           opsWithUsers, opsToHoist))
       return false;
   } else if (!isMemoryEffectFree(&op) &&
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 7bd25dbbca6365a..12a28c2e23b221a 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -429,7 +429,7 @@ static ParallelComputeFunction createParallelComputeFunction(
       mapping.map(op.getInductionVars(), computeBlockInductionVars);
       mapping.map(computeFuncType.captures, captures);
 
-      for (auto &bodyOp : op.getLoopBody().getOps())
+      for (auto &bodyOp : op.getRegion().getOps())
         b.clone(bodyOp, mapping);
     };
   };
@@ -732,7 +732,7 @@ AsyncParallelForRewrite::matchAndRewrite(scf::ParallelOp op,
 
   // Make sure that all constants will be inside the parallel operation body to
   // reduce the number of parallel compute function arguments.
-  cloneConstantsIntoTheRegion(op.getLoopBody(), rewriter);
+  cloneConstantsIntoTheRegion(op.getRegion(), rewriter);
 
   // Compute trip count for each loop induction variable:
   //   tripCount = ceil_div(upperBound - lowerBound, step);
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 2ee21099cfb14c7..7c6639304d97c58 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -219,8 +219,7 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
             // Replace all uses of the `transferRead` with the corresponding
             // basic block argument.
             transferRead.getVector().replaceUsesWithIf(
-                newForOp.getLoopBody().getArguments().back(),
-                [&](OpOperand &use) {
+                newForOp.getBody()->getArguments().back(), [&](OpOperand &use) {
                   Operation *user = use.getOwner();
                   return newForOp->isProperAncestor(user);
                 });
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
index 5a44a85c95c75a2..72b684aaa864c7a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
@@ -199,9 +199,9 @@ static void replaceIndexOpsByInductionVariables(RewriterBase &rewriter,
   // Replace the index operations in the body of the innermost loop op.
   if (!loopOps.empty()) {
     auto loopOp = cast<LoopLikeOpInterface>(loopOps.back());
-    for (IndexOp indexOp :
-         llvm::make_early_inc_range(loopOp.getLoopBody().getOps<IndexOp>()))
-      rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
+    for (Region *r : loopOp.getLoopRegions())
+      for (IndexOp indexOp : llvm::make_early_inc_range(r->getOps<IndexOp>()))
+        rewriter.replaceOp(indexOp, allIvs[indexOp.getDim()]);
   }
 }
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
index f4556787668d45b..7ab4ea41a2cd89d 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -303,7 +303,7 @@ static Operation *isTensorChunkAccessedByUnknownOp(Operation *writeOp,
       // pass-through tensor arguments left from previous level of
       // hoisting.
       if (auto forUser = dyn_cast<scf::ForOp>(user)) {
-        Value arg = forUser.getLoopBody().getArgument(
+        Value arg = forUser.getBody()->getArgument(
             use.getOperandNumber() - forUser.getNumControlOperands() +
             /*iv value*/ 1);
         uses.push_back(arg.getUses());
diff --git a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
index eb1df2a87b99a7d..397bd5856bcb07c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/MultiBuffer.cpp
@@ -152,8 +152,9 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
   std::optional<Value> inductionVar = candidateLoop.getSingleInductionVar();
   std::optional<OpFoldResult> lowerBound = candidateLoop.getSingleLowerBound();
   std::optional<OpFoldResult> singleStep = candidateLoop.getSingleStep();
-  if (!inductionVar || !lowerBound || !singleStep) {
-    LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb or step\n");
+  if (!inductionVar || !lowerBound || !singleStep ||
+      !llvm::hasSingleElement(candidateLoop.getLoopRegions())) {
+    LLVM_DEBUG(DBGS() << "Skip alloc: no single iv, lb, step or region\n");
     return failure();
   }
 
@@ -184,7 +185,8 @@ mlir::memref::multiBuffer(RewriterBase &rewriter, memref::AllocOp allocOp,
 
   // 3. Within the loop, build the modular leading index (i.e. each loop
   // iteration %iv accesses slice ((%iv - %lb) / %step) % %mb_factor).
-  rewriter.setInsertionPointToStart(&candidateLoop.getLoopBody().front());
+  rewriter.setInsertionPointToStart(
+      &candidateLoop.getLoopRegions().front()->front());
   Value ivVal = *inductionVar;
   Value lbVal = getValueOrCreateConstantIndexOp(rewriter, loc, *lowerBound);
   Value stepVal = getValueOrCreateConstantIndexOp(rewriter, loc, *singleStep);
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 2a760c76d2f6867..45e68e23a71d60e 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -530,7 +530,7 @@ ParseResult ForOp::parse(OpAsmParser &parser, OperationState &result) {
   return success();
 }
 
-Region &ForOp::getLoopBody() { return getRegion(); }
+SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 
 ForOp mlir::scf::getForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
@@ -558,11 +558,11 @@ void ForOp::getSuccessorRegions(RegionBranchPoint point,
   // Both the operation itself and the region may be branching into the body or
   // back into the operation itself. It is possible for loop not to enter the
   // body.
-  regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
+  regions.push_back(RegionSuccessor(&getRegion(), getRegionIterArgs()));
   regions.push_back(RegionSuccessor(getResults()));
 }
 
-Region &ForallOp::getLoopBody() { return getRegion(); }
+SmallVector<Region *> ForallOp::getLoopRegions() { return {&getRegion()}; }
 
 /// Promotes the loop body of a forallOp to its containing block if it can be
 /// determined that the loop has a single iteration.
@@ -894,7 +894,7 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
       blockArgs.reserve(op.getInitArgs().size() + 1);
       blockArgs.push_back(op.getLowerBound());
       llvm::append_range(blockArgs, op.getInitArgs());
-      replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
+      replaceOpWithRegion(rewriter, op, op.getRegion(), blockArgs);
       return success();
     }
 
@@ -2872,7 +2872,7 @@ void ParallelOp::print(OpAsmPrinter &p) {
       /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
 }
 
-Region &ParallelOp::getLoopBody() { return getRegion(); }
+SmallVector<Region *> ParallelOp::getLoopRegions() { return {&getRegion()}; }
 
 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
@@ -2926,7 +2926,7 @@ struct ParallelOpSingleOrZeroIterationDimsFolder
       // loop body and nested ReduceOp's
       SmallVector<Value> results;
       results.reserve(op.getInitVals().size());
-      for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
+      for (auto &bodyOp : op.getBody()->without_terminator()) {
         auto reduce = dyn_cast<ReduceOp>(bodyOp);
         if (!reduce) {
           rewriter.clone(bodyOp, mapping);
@@ -2965,7 +2965,7 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
 
   LogicalResult matchAndRewrite(ParallelOp op,
                                 PatternRewriter &rewriter) const override {
-    Block &outerBody = op.getLoopBody().front();
+    Block &outerBody = *op.getBody();
     if (!llvm::hasSingleElement(outerBody.without_terminator()))
       return failure();
 
@@ -2985,7 +2985,7 @@ struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
 
     auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                            ValueRange iterVals, ValueRange) {
-      Block &innerBody = innerOp.getLoopBody().front();
+      Block &innerBody = *innerOp.getBody();
       assert(iterVals.size() ==
              (outerBody.getNumArguments() + innerBody.getNumArguments()));
       IRMapping mapping;
@@ -3203,6 +3203,10 @@ void WhileOp::getSuccessorRegions(RegionBranchPoint point,
   regions.emplace_back(&getAfter(), getAfter().getArguments());
 }
 
+SmallVector<Region *> WhileOp::getLoopRegions() {
+  return {&getBefore(), &getAfter()};
+}
+
 /// Parses a `while` op.
 ///
 /// op ::= `scf.while` assignments `:` function-type region `do` region
diff --git a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
index 47f6da25b325144..88c6f3da656f3ba 100644
--- a/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/IR/ValueBoundsOpInterfaceImpl.cpp
@@ -35,11 +35,8 @@ struct ForOpInterface
 
     // An EQ constraint can be added if the yielded value (dimension size)
     // equals the corresponding block argument (dimension size).
-    assert(forOp.getLoopBody().hasOneBlock() &&
-           "multiple blocks not supported");
-    Value yieldedValue =
-        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator())
-            .getOperand(iterArgIdx);
+    Value yieldedValue = cast<scf::YieldOp>(forOp.getBody()->getTerminator())
+                             .getOperand(iterArgIdx);
     Value iterArg = forOp.getRegionIterArg(iterArgIdx);
     Value initArg = forOp.getInitArgs()[iterArgIdx];
 
@@ -68,7 +65,7 @@ struct ForOpInterface
           // Stop when reaching a value that is defined outside of the loop. It
           // is impossible to reach an iter_arg from there.
           Operation *op = v.getDefiningOp();
-          return forOp.getLoopBody().findAncestorOpInRegion(*op) == nullptr;
+          return forOp.getRegion().findAncestorOpInRegion(*op) == nullptr;
         });
     if (failed(status))
       return;
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index e09e14dbeb2c7f3..bcbc693a9742ccc 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -489,8 +489,7 @@ struct ForOpInterface
     auto forOp = cast<scf::ForOp>(op);
     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    auto yieldOp =
-        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     bool equivalentYield = state.areEquivalentBufferizedValues(
         bbArg, yieldOp->getOperand(opResult.getResultNumber()));
     return equivalentYield ? BufferRelation::Equivalent
@@ -525,8 +524,7 @@ struct ForOpInterface
     // satisfied. Otherwise, we cannot be sure and must yield a new buffer copy.
     // (New buffer copies do not alias with any buffer.)
     auto forOp = cast<scf::ForOp>(op);
-    auto yieldOp =
-        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPoint(yieldOp);
 
@@ -578,8 +576,7 @@ struct ForOpInterface
             .getResultNumber();
 
     // Compute the bufferized type.
-    auto yieldOp =
-        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     Value yieldedValue = yieldOp.getOperand(resultNum);
     BlockArgument iterArg = forOp.getRegionIterArgs()[resultNum];
     Value initArg = forOp.getInitArgs()[resultNum];
@@ -590,7 +587,7 @@ struct ForOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto forOp = cast<scf::ForOp>(op);
-    Block *oldLoopBody = &forOp.getLoopBody().front();
+    Block *oldLoopBody = forOp.getBody();
 
     // Indices of all iter_args that have tensor type. These are the ones that
     // are bufferized.
@@ -624,7 +621,7 @@ struct ForOpInterface
         forOp.getLoc(), forOp.getLowerBound(), forOp.getUpperBound(),
         forOp.getStep(), castedInitArgs);
     newForOp->setAttrs(forOp->getAttrs());
-    Block *loopBody = &newForOp.getLoopBody().front();
+    Block *loopBody = newForOp.getBody();
 
     // Set up new iter_args. The loop body uses tensors, so wrap the (memref)
     // iter_args of the new loop in ToTensorOps.
@@ -657,8 +654,7 @@ struct ForOpInterface
       return success();
 
     auto forOp = cast<scf::ForOp>(op);
-    auto yieldOp =
-        cast<scf::YieldOp>(forOp.getLoopBody().front().getTerminator());
+    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     for (OpResult opResult : op->getOpResults()) {
       if (!isa<TensorType>(opResult.getType()))
         continue;
diff --git a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
index 35d242327178e7f..7932c38a3e8d8bb 100644
--- a/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/StructuralTypeConversions.cpp
@@ -126,7 +126,7 @@ class ConvertForOpTypes
     // new op's regions doesn't remove the child ops from the worklist).
 
     // convertRegionTypes already takes care of 1:N conversion.
-    if (failed(rewriter.convertRegionTypes(&op.getLoopBody(), *typeConverter)))
+    if (failed(rewriter.convertRegionTypes(&op.getRegion(), *typeConverter)))
       return std::nullopt;
 
     // Unpacked the iteration arguments.
@@ -146,8 +146,8 @@ class ConvertForOpTypes
     // We do not need the empty block created by rewriter.
     rewriter.eraseBlock(newOp.getBody(0));
     // Inline the type converted region from the original operation.
-    rewriter.inlineRegionBefore(op.getLoopBody(), newOp.getLoopBody(),
-                                newOp.getLoopBody().end());
+    rewriter.inlineRegionBefore(op.getRegion(), newOp.getRegion(),
+                                newOp.getRegion().end());
 
     return newOp;
   }
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 222a9aa395c4f09..411503700eb01c3 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -1006,9 +1006,9 @@ scf::ForallOp mlir::fuseIndependentSiblingForallLoops(scf::ForallOp target,
 
   // Append everything except the terminator into the fused operation.
   rewriter.setInsertionPointToStart(fusedLoop.getBody());
-  for (Operation &op : target.getLoopBody().begin()->without_terminator())
+  for (Operation &op : target.getBody()->without_terminator())
     rewriter.clone(op, fusedMapping);
-  for (Operation &op : source.getLoopBody().begin()->without_terminator())
+  for (Operation &op : source.getBody()->without_terminator())
     rewriter.clone(op, fusedMapping);
 
   // Fuse the old terminator in_parallel ops into the new one.
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
index 737058c543dacef..efdd3347558b44b 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/SparseGPUCodegen.cpp
@@ -300,8 +300,8 @@ static void genGPUCode(PatternRewriter &rewriter, gpu::GPUFuncOp gpuFunc,
   //   }
   Value upper = irMap.lookup(forallOp.getUpperBound()[0]);
   scf::ForOp forOp = rewriter.create<scf::ForOp>(loc, row, upper, inc);
-  rewriter.cloneRegionBefore(forallOp.getLoopBody(), forOp.getLoopBody(),
-                             forOp.getLoopBody().begin(), irMap);
+  rewriter.cloneRegionBefore(forallOp.getRegion(), forOp.getRegion(),
+                             forOp.getRegion().begin(), irMap);
 
   // Done.
   rewriter.setInsertionPointAfter(forOp);
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 0696bf386bf752f..616aad8c4aaf08f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -69,7 +69,7 @@ struct TosaInlinerInterface : public DialectInlinerInterface {
 //===----------------------------------------------------------------------===//
 
 /// Returns the while loop body.
-Region &tosa::WhileOp::getLoopBody() { return getBody(); }
+SmallVector<Region *> tosa::WhileOp::getLoopRegions() { return {&getBody()}; }
 
 //===----------------------------------------------------------------------===//
 // Tosa dialect initialization.
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index d546ab3be8bda33..080492da6ae4b97 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -48,7 +48,7 @@ static bool canBeHoisted(Operation *op,
 }
 
 size_t mlir::moveLoopInvariantCode(
-    RegionRange regions,
+    ArrayRef<Region *> regions,
     function_ref<bool(Value, Region *)> isDefinedOutsideRegion,
     function_ref<bool(Operation *, Region *)> shouldMoveOutOfRegion,
     function_ref<void(Operation *, Region *)> moveOutOfRegion) {
@@ -96,7 +96,7 @@ size_t mlir::moveLoopInvariantCode(
 
 size_t mlir::moveLoopInvariantCode(LoopLikeOpInterface loopLike) {
   return moveLoopInvariantCode(
-      &loopLike.getLoopBody(),
+      loopLike.getLoopRegions(),
       [&](Value value, Region *) {
         return loopLike.isDefinedOutsideOfLoop(value);
       },
diff --git a/mlir/test/Transforms/loop-invariant-code-motion.mlir b/mlir/test/Transforms/loop-invariant-code-motion.mlir
index 4526b23fa503dbf..e9b1e2235900103 100644
--- a/mlir/test/Transforms/loop-invariant-code-motion.mlir
+++ b/mlir/test/Transforms/loop-invariant-code-motion.mlir
@@ -929,3 +929,30 @@ func.func @speculate_dynamic_pack_and_unpack(%source: tensor<?x?xf32>,
   }
   return
 }
+
+// -----
+
+// CHECK-LABEL: func @hoist_from_scf_while(
+//  CHECK-SAME:     %[[arg0:.*]]: i32, %{{.*}}: i32)
+//   CHECK-DAG:   arith.constant 1 : i32
+//   CHECK-DAG:   %[[c2:.*]] = arith.constant 2 : i32
+//   CHECK-DAG:   %[[c10:.*]] = arith.constant 10 : i32
+//   CHECK-DAG:   %[[added:.*]] = arith.addi %[[arg0]], %[[c2]]
+//       CHECK:   scf.while
+//       CHECK:     %[[cmpi:.*]] = arith.cmpi slt, %{{.*}}, %[[added]]
+//       CHECK:     scf.condition(%[[cmpi]])
+func.func @hoist_from_scf_while(%arg0: i32, %arg1: i32) -> i32 {
+  %0 = scf.while (%arg2 = %arg1) : (i32) -> (i32) {
+    %c2 = arith.constant 2 : i32
+    %c10 = arith.constant 10 : i32
+    %added = arith.addi %arg0, %c2 : i32
+    %1 = arith.cmpi slt, %arg2, %added : i32
+    scf.condition(%1) %arg2 : i32
+  } do {
+  ^bb0(%arg2: i32):
+    %c1 = arith.constant 1 : i32
+    %added2 = arith.addi %c1, %arg2 : i32
+    scf.yield %added2 : i32
+  }
+  return %0 : i32
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index 0aa8ce4de9756fa..354a43c244e3bbe 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -2570,7 +2570,7 @@ def TestGraphLoopOp : TEST_Op<"graph_loop",
   }];
 
   let extraClassDeclaration = [{
-    mlir::Region &getLoopBody() { return getBody(); }
+    llvm::SmallVector<mlir::Region *> getLoopRegions() { return {&getBody()}; }
   }];
 }
 



More information about the flang-commits mailing list