[Mlir-commits] [mlir] [mlir][Interfaces] `LoopLikeOpInterface`: Add helper to get yielded values (PR #67305)

Matthias Springer llvmlistbot at llvm.org
Mon Sep 25 02:35:50 PDT 2023


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/67305

Add a new interface method that returns the yielded values.
    
Also add a verifier that checks the number of inits, region iter_args, and yielded values. Most (but not all) of these invariants are already checked by `RegionBranchOpInterface`, but the `LoopLikeOpInterface` verifier now provides additional error messages that are easier to read.

Depends on #67121. Only review the top commit.


>From dd90d55bf1771b89843f9d0add21a427fae260ae Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 Sep 2023 11:34:11 +0200
Subject: [PATCH 1/2] [mlir][Interfaces] `LoopLikeOpInterface`: Add
 `replaceWithAdditionalYields`

`affine::replaceForOpWithNewYields` and `replaceLoopWithNewYields` (for "scf.for") are replaced by the `LoopLikeOpInterface` method `replaceWithAdditionalYields`, so additional loop-carried variables can now be added to "scf.for" and "affine.for" uniformly (no more `TypeSwitch` needed).

Note: `scf.while` and other loops with loop-carried variables can also implement `replaceWithAdditionalYields`; to keep this commit small, this commit does not do so.

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../mlir/Dialect/Affine/IR/AffineOps.h        | 14 ---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  2 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    |  3 +-
 mlir/include/mlir/Dialect/SCF/Utils/Utils.h   | 42 +-------
 .../mlir/Interfaces/LoopLikeInterface.h       |  7 ++
 .../mlir/Interfaces/LoopLikeInterface.td      | 43 +++++++++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      | 96 ++++++++++---------
 .../Dialect/Affine/Utils/LoopFusionUtils.cpp  | 20 ++--
 mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp   | 22 +++--
 .../Linalg/Transforms/HoistPadding.cpp        |  7 +-
 .../Dialect/Linalg/Transforms/Hoisting.cpp    | 53 +++-------
 .../Linalg/Transforms/SubsetHoisting.cpp      | 20 ++--
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 53 ++++++++++
 .../SCF/Transforms/TileUsingInterface.cpp     |  3 +-
 mlir/lib/Dialect/SCF/Utils/Utils.cpp          | 91 +++---------------
 .../scf-replace-with-new-yields.mlir          |  3 -
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    | 18 ++--
 17 files changed, 241 insertions(+), 256 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
index 704c2704536d20a..56b4a609e62c001 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
@@ -490,20 +490,6 @@ void buildAffineLoopNest(OpBuilder &builder, Location loc, ValueRange lbs,
                          function_ref<void(OpBuilder &, Location, ValueRange)>
                              bodyBuilderFn = nullptr);
 
-/// Replace `loop` with a new loop where `newIterOperands` are appended with
-/// new initialization values and `newYieldedValues` are added as new yielded
-/// values. The returned ForOp has `newYieldedValues.size()` new result values.
-/// Additionally, if `replaceLoopResults` is true, all uses of
-/// `loop.getResults()` are replaced with the first `loop.getNumResults()`
-/// return values  of the original loop respectively. The original loop is
-/// deleted and the new loop returned.
-/// Prerequisite: `newIterOperands.size() == newYieldedValues.size()`.
-AffineForOp replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
-                                      ValueRange newIterOperands,
-                                      ValueRange newYieldedValues,
-                                      ValueRange newIterArgs,
-                                      bool replaceLoopResults = true);
-
 /// AffineBound represents a lower or upper bound in the for operation.
 /// This class does not own the underlying operands. Instead, it refers
 /// to the operands stored in the AffineForOp. Its life span should not exceed
diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index 5a1baaf4e1611c8..d8ef0506d0822d7 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -120,7 +120,7 @@ def AffineForOp : Affine_Op<"for",
     [AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound"]>,
+      "getSingleUpperBound", "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
   let summary = "for operation";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index 0c93989ca99a4eb..e62015b84888a18 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,7 +121,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInits", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration"]>,
+        "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
+        "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
diff --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bde30c9c3528dbc..9bdd6eb833876f0 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -34,39 +34,6 @@ class CallOp;
 class FuncOp;
 } // namespace func
 
-/// Replace the `loop` with `newIterOperands` added as new initialization
-/// values. `newYieldValuesFn` is a callback that can be used to specify
-/// the additional values to be yielded by the loop. The number of
-/// values returned by the callback should match the number of new
-/// initialization values. This function
-/// - Moves (i.e. doesnt clone) operations from the `loop` to the newly created
-///   loop
-/// - Replaces the uses of `loop` with the new loop.
-/// - `loop` isnt erased, but is left in a "no-op" state where the body of the
-///   loop just yields the basic block arguments that correspond to the
-///   initialization values of a loop. The loop is dead after this method.
-/// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
-///   `newIterOperands` within the generated new loop are replaced
-///   with the corresponding `BlockArgument` in the loop body.
-using NewYieldValueFn = std::function<SmallVector<Value>(
-    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
-scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
-                                    ValueRange newIterOperands,
-                                    const NewYieldValueFn &newYieldValuesFn,
-                                    bool replaceIterOperandsUsesInLoop = true);
-// Simpler API if the new yields are just a list of values that can be
-// determined ahead of time.
-inline scf::ForOp
-replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
-                         ValueRange newIterOperands, ValueRange newYields,
-                         bool replaceIterOperandsUsesInLoop = true) {
-  auto fn = [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
-    return SmallVector<Value>(newYields.begin(), newYields.end());
-  };
-  return replaceLoopWithNewYields(builder, loop, newIterOperands, fn,
-                                  replaceIterOperandsUsesInLoop);
-}
-
 /// Update a perfectly nested loop nest to yield new values from the innermost
 /// loop and propagating it up through the loop nest. This function
 /// - Expects `loopNest` to be a perfectly nested loop with outer most loop
@@ -82,11 +49,10 @@ replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
 /// - If `replaceIterOperandsUsesInLoop` is true, all uses of the
 ///   `newIterOperands` within the generated new loop are replaced with the
 ///   corresponding `BlockArgument` in the loop body.
-SmallVector<scf::ForOp>
-replaceLoopNestWithNewYields(OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
-                             ValueRange newIterOperands,
-                             const NewYieldValueFn &newYieldValueFn,
-                             bool replaceIterOperandsUsesInLoop = true);
+SmallVector<scf::ForOp> replaceLoopNestWithNewYields(
+    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
+    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
+    bool replaceIterOperandsUsesInLoop = true);
 
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 9d81a61fac88566..0eebb984e5897ae 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -17,6 +17,13 @@
 
 namespace mlir {
 class RewriterBase;
+
+/// A function that returns the additional yielded values during
+/// `replaceWithAdditionalYields`. `newBbArgs` are the newly added region
+/// iter_args. This function should return as many values as there are block
+/// arguments in `newBbArgs`.
+using NewYieldValuesFn = std::function<SmallVector<Value>(
+    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
 } // namespace mlir
 
 /// Include the generated interface declarations.
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index cb6b2f4ed4ae8b5..ded0a29292ff6f0 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -141,6 +141,31 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         return ::mlir::Block::BlockArgListType();
       }]
     >,
+    InterfaceMethod<[{
+        Append the specified additional "init" operands: replace this loop with
+        a new loop that has the additional init operands. The loop body of
+        this loop is moved over to the new loop.
+
+        `newInitOperands` specifies the additional "init" operands.
+        `newYieldValuesFn` is a function that returns the yielded values (which
+        can be computed based on the additional region iter_args). If
+        `replaceInitOperandUsesInLoop` is set, all uses of the additional init
+        operands inside of this loop are replaced with the corresponding, newly
+        added region iter_args.
+
+        Note: Loops that do not support init/iter_args should return "failure".
+      }],
+      /*retTy=*/"::mlir::FailureOr<::mlir::LoopLikeOpInterface>",
+      /*methodName=*/"replaceWithAdditionalYields",
+      /*args=*/(ins "::mlir::RewriterBase &":$rewriter,
+                    "::mlir::ValueRange":$newInitOperands,
+                    "bool":$replaceInitOperandUsesInLoop,
+                    "const ::mlir::NewYieldValuesFn &":$newYieldValuesFn),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::failure();
+      }]
+    >,
   ];
 
   let extraClassDeclaration = [{
@@ -149,6 +174,24 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
     /// because the control flow graph is cyclic
     static bool blockIsInLoop(Block *block);
   }];
+
+  let extraSharedClassDeclaration = [{
+    /// Append the specified additional "init" operands: replace this loop with
+    /// a new loop that has the additional init operands. The loop body of this
+    /// loop is moved over to the new loop.
+    ///
+    /// The newly added region iter_args are yielded from the loop.
+    ::mlir::FailureOr<::mlir::LoopLikeOpInterface>
+        replaceWithAdditionalIterOperands(::mlir::RewriterBase &rewriter,
+                                          ::mlir::ValueRange newInitOperands,
+                                          bool replaceInitOperandUsesInLoop) {
+      return $_op.replaceWithAdditionalYields(
+          rewriter, newInitOperands, replaceInitOperandUsesInLoop,
+          [](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+            return SmallVector<Value>(newBBArgs);
+          });
+    }
+  }];
 }
 
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 9ecc568883a3b0b..6c060c90e24af82 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2575,6 +2575,58 @@ std::optional<OpFoldResult> AffineForOp::getSingleUpperBound() {
   return OpFoldResult(b.getI64IntegerAttr(getConstantUpperBound()));
 }
 
+FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
+    RewriterBase &rewriter, ValueRange newInitOperands,
+    bool replaceInitOperandUsesInLoop,
+    const NewYieldValuesFn &newYieldValuesFn) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getInits());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  AffineForOp newLoop = rewriter.create<AffineForOp>(
+      getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
+      getUpperBoundOperands(), getUpperBoundMap(), getStep(), inits);
+
+  // Generate the new yield values and append them to the affine.yield op.
+  auto yieldOp = cast<AffineYieldOp>(getBody()->getTerminator());
+  ArrayRef<BlockArgument> newIterArgs =
+      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
+    assert(newInitOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    rewriter.updateRootInPlace(yieldOp, [&]() {
+      yieldOp.getOperandsMutable().append(newYieldedValues);
+    });
+  }
+
+  // Move the loop body to the new op.
+  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           getBody()->getNumArguments()));
+
+  if (replaceInitOperandUsesInLoop) {
+    // Replace all uses of `newInitOperands` with the corresponding basic block
+    // arguments.
+    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
+      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
+                                 [&](OpOperand &use) {
+                                   Operation *user = use.getOwner();
+                                   return newLoop->isProperAncestor(user);
+                                 });
+    }
+  }
+
+  // Replace the old loop.
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 Speculation::Speculatability AffineForOp::getSpeculatability() {
   // `affine.for (I = Start; I < End; I += 1)` terminates for all values of
   // Start and End.
@@ -2725,50 +2777,6 @@ void mlir::affine::buildAffineLoopNest(
                           buildAffineLoopFromValues);
 }
 
-AffineForOp mlir::affine::replaceForOpWithNewYields(OpBuilder &b,
-                                                    AffineForOp loop,
-                                                    ValueRange newIterOperands,
-                                                    ValueRange newYieldedValues,
-                                                    ValueRange newIterArgs,
-                                                    bool replaceLoopResults) {
-  assert(newIterOperands.size() == newYieldedValues.size() &&
-         "newIterOperands must be of the same size as newYieldedValues");
-  // Create a new loop before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(b);
-  b.setInsertionPoint(loop);
-  auto operands = llvm::to_vector<4>(loop.getInits());
-  operands.append(newIterOperands.begin(), newIterOperands.end());
-  SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
-  SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
-  SmallVector<Value, 4> steps(loop.getStep());
-  auto lbMap = loop.getLowerBoundMap();
-  auto ubMap = loop.getUpperBoundMap();
-  AffineForOp newLoop =
-      b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
-                            loop.getStep(), operands);
-  // Take the body of the original parent loop.
-  newLoop.getRegion().takeBody(loop.getRegion());
-  for (Value val : newIterArgs)
-    newLoop.getRegion().addArgument(val.getType(), val.getLoc());
-
-  // Update yield operation with new values to be added.
-  if (!newYieldedValues.empty()) {
-    auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
-    b.setInsertionPoint(yield);
-    auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
-    yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
-    b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
-    yield.erase();
-  }
-  if (replaceLoopResults) {
-    for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
-                                                    loop.getNumResults()))) {
-      std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
-    }
-  }
-  return newLoop;
-}
-
 //===----------------------------------------------------------------------===//
 // AffineIfOp
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
index 3ecb8664e3fd765..5053b08ee0834cd 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopFusionUtils.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/Affine/LoopUtils.h"
 #include "mlir/IR/IRMapping.h"
 #include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <optional>
@@ -361,16 +362,22 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
   std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
   if (!tripCount || *tripCount != 1)
     return failure();
-  auto iterOperands = forOp.getInits();
   auto *parentOp = forOp->getParentOp();
   if (!isa<AffineForOp>(parentOp))
     return failure();
-  auto newOperands = forOp.getBody()->getTerminator()->getOperands();
-  OpBuilder b(parentOp);
+  SmallVector<Value> newOperands;
+  llvm::append_range(newOperands,
+                     forOp.getBody()->getTerminator()->getOperands());
+  IRRewriter rewriter(parentOp->getContext());
+  int64_t parentOpNumResults = parentOp->getNumResults();
   // Replace the parent loop and add iteroperands and results from the `forOp`.
   AffineForOp parentForOp = forOp->getParentOfType<AffineForOp>();
-  AffineForOp newLoop = replaceForOpWithNewYields(
-      b, parentForOp, iterOperands, newOperands, forOp.getRegionIterArgs());
+  AffineForOp newLoop =
+      cast<AffineForOp>(*parentForOp.replaceWithAdditionalYields(
+          rewriter, forOp.getInits(), /*replaceInitOperandUsesInLoop=*/false,
+          [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+            return newOperands;
+          }));
 
   // For sibling-fusion users, collect operations that use the results of the
   // `forOp` outside the new parent loop that has absorbed all its iter args
@@ -387,7 +394,7 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
   // Update the results of the `forOp` in the new loop.
   for (unsigned i = 0, e = forOp.getNumResults(); i != e; ++i) {
     forOp.getResult(i).replaceAllUsesWith(
-        newLoop.getResult(i + parentOp->getNumResults()));
+        newLoop.getResult(i + parentOpNumResults));
   }
   // For sibling-fusion users, move operations that use the results of the
   // `forOp` outside the new parent loop
@@ -412,7 +419,6 @@ static LogicalResult promoteSingleIterReductionLoop(AffineForOp forOp,
   parentBlock->getOperations().splice(Block::iterator(forOp),
                                       forOp.getBody()->getOperations());
   forOp.erase();
-  parentForOp.erase();
   return success();
 }
 
diff --git a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
index e6c4b2f8447470c..9d8ed9b4ac93387 100644
--- a/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
+++ b/mlir/lib/Dialect/Affine/Utils/LoopUtils.cpp
@@ -1197,9 +1197,9 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
   // `unrollJamFactor` copies of its iterOperands, iter_args and yield
   // operands.
   SmallVector<AffineForOp, 4> newLoopsWithIterArgs;
-  OpBuilder builder(forOp.getContext());
+  IRRewriter rewriter(forOp.getContext());
   for (AffineForOp oldForOp : loopsWithIterArgs) {
-    SmallVector<Value, 4> dupIterOperands, dupIterArgs, dupYieldOperands;
+    SmallVector<Value> dupIterOperands, dupYieldOperands;
     ValueRange oldIterOperands = oldForOp.getInits();
     ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
     ValueRange oldYieldOperands =
@@ -1208,19 +1208,21 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     // fix iterOperands and yield operands after cloning of sub-blocks.
     for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
       dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
-      dupIterArgs.append(oldIterArgs.begin(), oldIterArgs.end());
       dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
     }
     // Create a new loop with additional iterOperands, iter_args and yield
     // operands. This new loop will take the loop body of the original loop.
-    AffineForOp newForOp = affine::replaceForOpWithNewYields(
-        builder, oldForOp, dupIterOperands, dupYieldOperands, dupIterArgs);
+    bool forOpReplaced = oldForOp == forOp;
+    AffineForOp newForOp =
+        cast<AffineForOp>(*oldForOp.replaceWithAdditionalYields(
+            rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+            [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+              return dupYieldOperands;
+            }));
     newLoopsWithIterArgs.push_back(newForOp);
     // `forOp` has been replaced with a new loop.
-    if (oldForOp == forOp)
+    if (forOpReplaced)
       forOp = newForOp;
-    assert(oldForOp.use_empty() && "old for op should not have any user");
-    oldForOp.erase();
     // Update `operandMaps` for `newForOp` iterArgs and results.
     ValueRange newIterArgs = newForOp.getRegionIterArgs();
     unsigned oldNumIterArgs = oldIterArgs.size();
@@ -1294,7 +1296,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
     // into one value. For example, for %0:2 = affine.for ... and addf, we add
     // %1 = arith.addf %0#0, %0#1, and replace the following uses of %0#0 with
     // %1.
-    builder.setInsertionPointAfter(forOp);
+    rewriter.setInsertionPointAfter(forOp);
     auto loc = forOp.getLoc();
     unsigned oldNumResults = forOp.getNumResults() / unrollJamFactor;
     for (LoopReduction &reduction : reductions) {
@@ -1305,7 +1307,7 @@ LogicalResult mlir::affine::loopUnrollJamByFactor(AffineForOp forOp,
       for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
         rhs = forOp.getResult(i * oldNumResults + pos);
         // Create ops based on reduction type.
-        lhs = arith::getReductionOp(reduction.kind, builder, loc, lhs, rhs);
+        lhs = arith::getReductionOp(reduction.kind, rewriter, loc, lhs, rhs);
         if (!lhs)
           return failure();
         Operation *op = lhs.getDefiningOp();
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index a9debb7bbc489a4..72bd2b409f5d52b 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -842,8 +842,11 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
         outerSliceOp.getMixedStrides());
     rewriter.replaceAllUsesWith(forOp.getResult(iterArgNumber), extracted);
   }
-  scf::ForOp newForOp =
-      replaceLoopWithNewYields(rewriter, forOp, initArgs, yieldOperands);
+  scf::ForOp newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
+      rewriter, initArgs, /*replaceInitOperandUsesInLoop=*/true,
+      [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs) {
+        return yieldOperands;
+      }));
 
   LLVM_DEBUG(DBGS() << "newForOp results: " << newForOp.getNumResults()
                     << "\n");
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
index 7c6639304d97c58..1c68ca49725effb 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Hoisting.cpp
@@ -191,47 +191,24 @@ void mlir::linalg::hoistRedundantVectorTransfers(func::FuncOp func) {
       transferWrite->moveAfter(loop);
 
       // Rewrite `loop` with new yields by cloning and erase the original loop.
-      OpBuilder b(transferRead);
-      NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
-                                    ArrayRef<BlockArgument> newBBArgs) {
+      IRRewriter rewriter(transferRead.getContext());
+      NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
+                                     ArrayRef<BlockArgument> newBBArgs) {
         return SmallVector<Value>{transferWrite.getVector()};
       };
 
-      // Transfer write has been hoisted, need to update the written vector by
-      // the value yielded by the newForOp.
-      return TypeSwitch<Operation *, WalkResult>(loop)
-          .Case<scf::ForOp>([&](scf::ForOp scfForOp) {
-            auto newForOp = replaceLoopWithNewYields(
-                b, scfForOp, transferRead.getVector(), yieldFn);
-            transferWrite.getVectorMutable().assign(
-                newForOp.getResults().back());
-            changed = true;
-            loop.erase();
-            // Need to interrupt and restart because erasing the loop messes up
-            // the walk.
-            return WalkResult::interrupt();
-          })
-          .Case<affine::AffineForOp>([&](affine::AffineForOp affineForOp) {
-            auto newForOp = replaceForOpWithNewYields(
-                b, affineForOp, transferRead.getVector(),
-                SmallVector<Value>{transferWrite.getVector()},
-                transferWrite.getVector());
-            // Replace all uses of the `transferRead` with the corresponding
-            // basic block argument.
-            transferRead.getVector().replaceUsesWithIf(
-                newForOp.getBody()->getArguments().back(), [&](OpOperand &use) {
-                  Operation *user = use.getOwner();
-                  return newForOp->isProperAncestor(user);
-                });
-            transferWrite.getVectorMutable().assign(
-                newForOp.getResults().back());
-            changed = true;
-            loop.erase();
-            // Need to interrupt and restart because erasing the loop messes up
-            // the walk.
-            return WalkResult::interrupt();
-          })
-          .Default([](Operation *) { return WalkResult::interrupt(); });
+      auto maybeNewLoop = loop.replaceWithAdditionalYields(
+          rewriter, transferRead.getVector(),
+          /*replaceInitOperandUsesInLoop=*/true, yieldFn);
+      if (failed(maybeNewLoop))
+        return WalkResult::interrupt();
+
+      transferWrite.getVectorMutable().assign(
+          maybeNewLoop->getOperation()->getResults().back());
+      changed = true;
+      // Need to interrupt and restart because erasing the loop messes up
+      // the walk.
+      return WalkResult::interrupt();
     });
   }
 }
diff --git a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
index 7ab4ea41a2cd89d..91e0d139ec5c2f0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/SubsetHoisting.cpp
@@ -363,13 +363,13 @@ static scf::ForOp hoistTransferReadWrite(
 
   // 2. Rewrite `loop` with an additional yield. This is the quantity that is
   // computed iteratively but whose storage has become loop-invariant.
-  NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
-                                ArrayRef<BlockArgument> newBBArgs) {
+  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
+                                 ArrayRef<BlockArgument> newBBArgs) {
     return SmallVector<Value>{transferWriteOp.getVector()};
   };
-  auto newForOp = replaceLoopWithNewYields(
-      rewriter, forOp, {transferReadOp.getVector()}, yieldFn);
-  rewriter.eraseOp(forOp);
+  auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
+      rewriter, {transferReadOp.getVector()},
+      /*replaceInitOperandUsesInLoop=*/true, yieldFn));
 
   // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
   auto yieldOp =
@@ -425,13 +425,13 @@ static scf::ForOp hoistExtractInsertSlice(RewriterBase &rewriter,
 
   // 2. Rewrite `loop` with an additional yield. This is the quantity that is
   // computed iteratively but whose storage has become loop-invariant.
-  NewYieldValueFn yieldFn = [&](OpBuilder &b, Location loc,
-                                ArrayRef<BlockArgument> newBBArgs) {
+  NewYieldValuesFn yieldFn = [&](OpBuilder &b, Location loc,
+                                 ArrayRef<BlockArgument> newBBArgs) {
     return SmallVector<Value>{insertSliceOp.getSource()};
   };
-  auto newForOp = replaceLoopWithNewYields(rewriter, forOp,
-                                           extractSliceOp.getResult(), yieldFn);
-  rewriter.eraseOp(forOp);
+  auto newForOp = cast<scf::ForOp>(*forOp.replaceWithAdditionalYields(
+      rewriter, extractSliceOp.getResult(),
+      /*replaceInitOperandUsesInLoop=*/true, yieldFn));
 
   // 3. Update the yield. Invariant: initArgNumber is the destination tensor.
   auto yieldOp =
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8788597a1cefcfa..8d8481421e18d57 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -529,6 +529,59 @@ SmallVector<Region *> ForOp::getLoopRegions() { return {&getRegion()}; }
 
 OperandRange ForOp::getInits() { return getInitArgs(); }
 
+FailureOr<LoopLikeOpInterface>
+ForOp::replaceWithAdditionalYields(RewriterBase &rewriter,
+                                   ValueRange newInitOperands,
+                                   bool replaceInitOperandUsesInLoop,
+                                   const NewYieldValuesFn &newYieldValuesFn) {
+  // Create a new loop right before this one, with the extra init operands.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(getOperation());
+  auto inits = llvm::to_vector(getInitArgs());
+  inits.append(newInitOperands.begin(), newInitOperands.end());
+  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
+      getLoc(), getLowerBound(), getUpperBound(), getStep(), inits,
+      [](OpBuilder &, Location, Value, ValueRange) {});
+
+  // Build the new yielded values and append them to the old body's scf.yield.
+  auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
+  ArrayRef<BlockArgument> newIterArgs =
+      newLoop.getBody()->getArguments().take_back(newInitOperands.size());
+  {
+    OpBuilder::InsertionGuard g(rewriter);
+    rewriter.setInsertionPoint(yieldOp);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(rewriter, getLoc(), newIterArgs);
+    assert(newInitOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    rewriter.updateRootInPlace(yieldOp, [&]() {
+      yieldOp.getResultsMutable().append(newYieldedValues);
+    });
+  }
+
+  // Merge the old body into newLoop (old block args -> leading new args).
+  rewriter.mergeBlocks(getBody(), newLoop.getBody(),
+                       newLoop.getBody()->getArguments().take_front(
+                           getBody()->getNumArguments()));
+
+  if (replaceInitOperandUsesInLoop) {
+    // Within the new loop, replace uses of `newInitOperands` with the
+    // corresponding new region iter_args.
+    for (auto it : llvm::zip(newInitOperands, newIterArgs)) {
+      rewriter.replaceUsesWithIf(std::get<0>(it), std::get<1>(it),
+                                 [&](OpOperand &use) {
+                                   Operation *user = use.getOwner();
+                                   return newLoop->isProperAncestor(user);
+                                 });
+    }
+  }
+
+  // Replace the old loop with the leading (pre-existing) results of newLoop.
+  rewriter.replaceOp(getOperation(),
+                     newLoop->getResults().take_front(getNumResults()));
+  return cast<LoopLikeOpInterface>(newLoop.getOperation());
+}
+
 ForOp mlir::scf::getForInductionVarOwner(Value val) {
   auto ivArg = llvm::dyn_cast<BlockArgument>(val);
   if (!ivArg)
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index ab59eac2ac4d6f8..e2fb417f5f07a6d 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -174,7 +174,7 @@ yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
                  ArrayRef<SmallVector<OpFoldResult>> tileOffsetsList,
                  ArrayRef<SmallVector<OpFoldResult>> tileSizesList,
                  MutableArrayRef<scf::ForOp> loops) {
-  NewYieldValueFn yieldValueFn =
+  NewYieldValuesFn yieldValueFn =
       [&](OpBuilder &b, Location loc,
           ArrayRef<BlockArgument> newBBArgs) -> SmallVector<Value> {
     SmallVector<Value> inserts;
@@ -196,7 +196,6 @@ yieldTiledValues(RewriterBase &rewriter, ValueRange initValues,
       replaceLoopNestWithNewYields(rewriter, loops, initValues, yieldValueFn,
                                    /*replaceIterOperandsUsesInLoop =*/false);
   for (const auto &loop : llvm::enumerate(loops)) {
-    rewriter.eraseOp(loop.value());
     loops[loop.index()] = newLoops[loop.index()];
   }
   return llvm::to_vector(llvm::map_range(
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 411503700eb01c3..5360c493f8f8d71 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -38,77 +38,9 @@ struct LoopParams {
 };
 } // namespace
 
-scf::ForOp
-mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
-                               ValueRange newIterOperands,
-                               const NewYieldValueFn &newYieldValuesFn,
-                               bool replaceIterOperandsUsesInLoop) {
-  // Create a new loop before the existing one, with the extra operands.
-  OpBuilder::InsertionGuard g(builder);
-  builder.setInsertionPoint(loop);
-  auto operands = llvm::to_vector(loop.getInitArgs());
-  llvm::append_range(operands, newIterOperands);
-  scf::ForOp newLoop = builder.create<scf::ForOp>(
-      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
-      operands, [](OpBuilder &, Location, Value, ValueRange) {});
-
-  Block *loopBody = loop.getBody();
-  Block *newLoopBody = newLoop.getBody();
-
-  // Move the body of the original loop to the new loop.
-  newLoopBody->getOperations().splice(newLoopBody->end(),
-                                      loopBody->getOperations());
-
-  // Generate the new yield values to use by using the callback and append the
-  // yield values to the scf.yield operation.
-  auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
-  ArrayRef<BlockArgument> newBBArgs =
-      newLoopBody->getArguments().take_back(newIterOperands.size());
-  {
-    OpBuilder::InsertionGuard g(builder);
-    builder.setInsertionPoint(yield);
-    SmallVector<Value> newYieldedValues =
-        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
-    assert(newIterOperands.size() == newYieldedValues.size() &&
-           "expected as many new yield values as new iter operands");
-    yield.getResultsMutable().append(newYieldedValues);
-  }
-
-  // Remap the BlockArguments from the original loop to the new loop
-  // BlockArguments.
-  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
-  for (auto it :
-       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
-    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
-
-  if (replaceIterOperandsUsesInLoop) {
-    // Replace all uses of `newIterOperands` with the corresponding basic block
-    // arguments.
-    for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
-      std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
-        Operation *user = use.getOwner();
-        return newLoop->isProperAncestor(user);
-      });
-    }
-  }
-
-  // Replace all uses of the original loop with corresponding values from the
-  // new loop.
-  loop.replaceAllUsesWith(
-      newLoop.getResults().take_front(loop.getNumResults()));
-
-  // Add a fake yield to the original loop body that just returns the
-  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
-  // The loop is dead. The caller is expected to erase it.
-  builder.setInsertionPointToEnd(loopBody);
-  builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
-
-  return newLoop;
-}
-
 SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
-    OpBuilder &builder, ArrayRef<scf::ForOp> loopNest,
-    ValueRange newIterOperands, const NewYieldValueFn &newYieldValueFn,
+    RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
+    ValueRange newIterOperands, const NewYieldValuesFn &newYieldValuesFn,
     bool replaceIterOperandsUsesInLoop) {
   if (loopNest.empty())
     return {};
@@ -146,31 +78,32 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
   // }
   // ```
   //
-  // The inner most loop is handled using the `replaceLoopWithNewYields`
+  // The inner most loop is handled using the `replaceWithAdditionalYields`
   // that works on a single loop.
   if (loopNest.size() == 1) {
-    auto innerMostLoop = replaceLoopWithNewYields(
-        builder, loopNest.back(), newIterOperands, newYieldValueFn,
-        replaceIterOperandsUsesInLoop);
+    auto innerMostLoop =
+        cast<scf::ForOp>(*loopNest.back().replaceWithAdditionalYields(
+            rewriter, newIterOperands, replaceIterOperandsUsesInLoop,
+            newYieldValuesFn));
     return {innerMostLoop};
   }
   // The outer loops are modified by calling this method recursively
   // - The return value of the inner loop is the value yielded by this loop.
   // - The region iter args of this loop are the init_args for the inner loop.
   SmallVector<scf::ForOp> newLoopNest;
-  NewYieldValueFn fn =
+  NewYieldValuesFn fn =
       [&](OpBuilder &innerBuilder, Location loc,
           ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
-    newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(),
-                                               innerNewBBArgs, newYieldValueFn,
+    newLoopNest = replaceLoopNestWithNewYields(rewriter, loopNest.drop_front(),
+                                               innerNewBBArgs, newYieldValuesFn,
                                                replaceIterOperandsUsesInLoop);
     return llvm::to_vector(llvm::map_range(
         newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
         [](OpResult r) -> Value { return r; }));
   };
   scf::ForOp outerMostLoop =
-      replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn,
-                               replaceIterOperandsUsesInLoop);
+      cast<scf::ForOp>(*loopNest.front().replaceWithAdditionalYields(
+          rewriter, newIterOperands, replaceIterOperandsUsesInLoop, fn));
   newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
   return newLoopNest;
 }
diff --git a/mlir/test/Transforms/scf-replace-with-new-yields.mlir b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
index 802f86e6806b318..d7af15f52e72a9b 100644
--- a/mlir/test/Transforms/scf-replace-with-new-yields.mlir
+++ b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
@@ -15,7 +15,4 @@ func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f3
 //       CHECK:     %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]]
 //       CHECK:     %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]]
 //       CHECK:     scf.yield %[[DOUBLE]], %[[DOUBLE2]]
-//       CHECK:   %[[OLDLOOP:.+]] = scf.for
-//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[ARG]])
-//       CHECK:     scf.yield %[[INIT]]
 //       CHECK:   return %[[NEWLOOP]]#0
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 455c9234b8c93de..1d40615305c02c4 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -50,19 +50,23 @@ struct TestSCFForUtilsPass
         auto newInitValues = forOp.getInitArgs();
         if (newInitValues.empty())
           return;
-        NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
-                                 ArrayRef<BlockArgument> newBBArgs) {
-          Block *block = newBBArgs.front().getOwner();
+        auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
+        SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
+                                          yieldOp.getResults().end());
+        NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
+                                  ArrayRef<BlockArgument> newBBArgs) {
           SmallVector<Value> newYieldValues;
-          for (auto yieldVal :
-               cast<scf::YieldOp>(block->getTerminator()).getResults()) {
+          for (auto yieldVal : oldYieldValues) {
             newYieldValues.push_back(
                 b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
           }
           return newYieldValues;
         };
-        OpBuilder b(forOp);
-        replaceLoopWithNewYields(b, forOp, newInitValues, fn);
+        IRRewriter rewriter(forOp.getContext());
+        if (failed(forOp.replaceWithAdditionalYields(
+                rewriter, newInitValues, /*replaceInitOperandUsesInLoop=*/true,
+                fn)))
+          signalPassFailure();
       });
     }
   }

>From 5a05ece8354489da36020d2e668504fcb8139e93 Mon Sep 17 00:00:00 2001
From: Matthias Springer <me at m-sp.org>
Date: Mon, 25 Sep 2023 11:34:29 +0200
Subject: [PATCH 2/2] [mlir][Interfaces] `LoopLikeOpInterface`: Add helper to
 get yielded values

Add a new interface method that returns the yielded values.

Also add a verifier that checks the number of inits/iter_args/yielded values. Most of the checked invariants (but not all of them) are already covered by the `RegionBranchOpInterface`, but the `LoopLikeOpInterface` now provides (additional) error messages that are easier to read.
---
 .../mlir/Dialect/Affine/IR/AffineOps.td       |  3 +-
 mlir/include/mlir/Dialect/SCF/IR/SCFOps.td    | 11 +++++-
 .../mlir/Interfaces/LoopLikeInterface.h       |  5 +++
 .../mlir/Interfaces/LoopLikeInterface.td      | 17 +++++++++
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  4 ++
 .../Linalg/Transforms/HoistPadding.cpp        |  5 +--
 mlir/lib/Dialect/SCF/IR/SCF.cpp               | 27 ++++++++------
 .../BufferizableOpInterfaceImpl.cpp           |  3 +-
 .../SCF/Transforms/LoopCanonicalization.cpp   |  5 +--
 mlir/lib/Interfaces/LoopLikeInterface.cpp     | 37 +++++++++++++++++++
 mlir/test/Dialect/SCF/invalid.mlir            | 30 ++++++++++++++-
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |  5 +--
 12 files changed, 124 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
index d8ef0506d0822d7..6eb0a4468fd91b5 100644
--- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
+++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
@@ -120,7 +120,8 @@ def AffineForOp : Affine_Op<"for",
     [AutomaticAllocationScope, ImplicitAffineTerminator, ConditionallySpeculatable,
      RecursiveMemoryEffects, DeclareOpInterfaceMethods<LoopLikeOpInterface,
      ["getSingleInductionVar", "getSingleLowerBound", "getSingleStep",
-      "getSingleUpperBound", "replaceWithAdditionalYields"]>,
+      "getSingleUpperBound", "getYieldedValues",
+      "replaceWithAdditionalYields"]>,
      DeclareOpInterfaceMethods<RegionBranchOpInterface,
      ["getEntrySuccessorOperands"]>]> {
   let summary = "for operation";
diff --git a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
index e62015b84888a18..c63703327f5bb34 100644
--- a/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
+++ b/mlir/include/mlir/Dialect/SCF/IR/SCFOps.td
@@ -121,8 +121,8 @@ def ExecuteRegionOp : SCF_Op<"execute_region", [
 def ForOp : SCF_Op<"for",
       [AutomaticAllocationScope, DeclareOpInterfaceMethods<LoopLikeOpInterface,
        ["getInits", "getSingleInductionVar", "getSingleLowerBound",
-        "getSingleStep", "getSingleUpperBound", "promoteIfSingleIteration",
-        "replaceWithAdditionalYields"]>,
+        "getSingleStep", "getSingleUpperBound", "getYieldedValues",
+        "promoteIfSingleIteration", "replaceWithAdditionalYields"]>,
        AllTypesMatch<["lowerBound", "upperBound", "step"]>,
        ConditionallySpeculatable,
        DeclareOpInterfaceMethods<RegionBranchOpInterface,
@@ -242,9 +242,11 @@ def ForOp : SCF_Op<"for",
         function_ref<void(OpBuilder &, Location, Value, ValueRange)>;
 
     Value getInductionVar() { return getBody()->getArgument(0); }
+
     Block::BlockArgListType getRegionIterArgs() {
       return getBody()->getArguments().drop_front(getNumInductionVars());
     }
+
     /// Return the `index`-th region iteration argument.
     BlockArgument getRegionIterArg(unsigned index) {
       assert(index < getNumRegionIterArgs() &&
@@ -1081,6 +1083,11 @@ def WhileOp : SCF_Op<"while",
 
     ConditionOp getConditionOp();
     YieldOp getYieldOp();
+
+    /// Return the values that are yielded from the "after" region (by the
+    /// YieldOp).
+    ValueRange getYieldedValues();
+
     Block::BlockArgListType getBeforeArguments();
     Block::BlockArgListType getAfterArguments();
     Block *getBeforeBody() { return &getBefore().front(); }
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.h b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
index 0eebb984e5897ae..7c7d378d0590ab1 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.h
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.h
@@ -24,6 +24,11 @@ class RewriterBase;
 /// arguments in `newBbArgs`.
 using NewYieldValuesFn = std::function<SmallVector<Value>(
     OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs)>;
+
+namespace detail {
+/// Verify invariants of the LoopLikeOpInterface.
+LogicalResult verifyLoopLikeOpInterface(Operation *op);
+} // namespace detail
 } // namespace mlir
 
 /// Include the generated interface declarations.
diff --git a/mlir/include/mlir/Interfaces/LoopLikeInterface.td b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
index ded0a29292ff6f0..23049669ec9348f 100644
--- a/mlir/include/mlir/Interfaces/LoopLikeInterface.td
+++ b/mlir/include/mlir/Interfaces/LoopLikeInterface.td
@@ -141,6 +141,17 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
         return ::mlir::Block::BlockArgListType();
       }]
     >,
+    InterfaceMethod<[{
+        Return the values yielded to the next iteration; empty by default.
+      }],
+      /*retTy=*/"::mlir::ValueRange",
+      /*methodName=*/"getYieldedValues",
+      /*args=*/(ins),
+      /*methodBody=*/"",
+      /*defaultImplementation=*/[{
+        return ::mlir::ValueRange();
+      }]
+    >,
     InterfaceMethod<[{
         Append the specified additional "init" operands: replace this loop with
         a new loop that has the additional init operands. The loop body of
@@ -192,6 +203,12 @@ def LoopLikeOpInterface : OpInterface<"LoopLikeOpInterface"> {
           });
     }
   }];
+
+  let verifyWithRegions = 1;
+
+  let verify = [{
+    return detail::verifyLoopLikeOpInterface($_op);
+  }];
 }
 
 #endif // MLIR_INTERFACES_LOOPLIKEINTERFACE
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index 6c060c90e24af82..4109167de44e860 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -2193,6 +2193,10 @@ unsigned AffineForOp::getNumIterOperands() {
   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
 }
 
+ValueRange AffineForOp::getYieldedValues() {
+  return cast<AffineYieldOp>(getBody()->getTerminator()).getOperands();
+}
+
 void AffineForOp::print(OpAsmPrinter &p) {
   p << ' ';
   p.printRegionArgument(getBody()->getArgument(0), /*argAttrs=*/{},
diff --git a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
index 72bd2b409f5d52b..8fef99bb375095a 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/HoistPadding.cpp
@@ -811,8 +811,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
   rewriter.setInsertionPointAfter(hoistedPackedTensor.getDefiningOp());
 
   unsigned iterArgNumber = forOp.getResultForOpOperand(*pUse).getResultNumber();
-  auto yieldOp = cast<scf::YieldOp>(forOp.getBody(0)->getTerminator());
-  auto yieldingExtractSliceOp = yieldOp->getOperand(iterArgNumber)
+  auto yieldingExtractSliceOp = forOp.getYieldedValues()[iterArgNumber]
                                     .getDefiningOp<tensor::ExtractSliceOp>();
   if (!yieldingExtractSliceOp)
     return tensor::ExtractSliceOp();
@@ -826,7 +825,7 @@ padThroughLoopIterArg(RewriterBase &rewriter, Value paddedValueBeforeHoisting,
 
   SmallVector<Value> initArgs = forOp.getInitArgs();
   initArgs[iterArgNumber] = hoistedPackedTensor;
-  SmallVector<Value> yieldOperands = yieldOp.getOperands();
+  SmallVector<Value> yieldOperands = llvm::to_vector(forOp.getYieldedValues());
   yieldOperands[iterArgNumber] = yieldingExtractSliceOp.getSource();
 
   int64_t numOriginalForOpResults = initArgs.size();
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index 8d8481421e18d57..508227d6e7ce4b0 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -400,7 +400,7 @@ LogicalResult ForOp::promoteIfSingleIteration(RewriterBase &rewriter) {
 
   // Replace all results with the yielded values.
   auto yieldOp = cast<scf::YieldOp>(getBody()->getTerminator());
-  rewriter.replaceAllUsesWith(getResults(), yieldOp.getOperands());
+  rewriter.replaceAllUsesWith(getResults(), getYieldedValues());
 
   // Replace block arguments with lower bound (replacement for IV) and
   // iter_args.
@@ -772,27 +772,26 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
   LogicalResult matchAndRewrite(scf::ForOp forOp,
                                 PatternRewriter &rewriter) const final {
     bool canonicalize = false;
-    Block &block = forOp.getRegion().front();
-    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
 
     // An internal flat vector of block transfer
     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
     // transformed block argument mappings. This plays the role of a
     // IRMapping for the particular use case of calling into
     // `inlineBlockBefore`.
+    int64_t numResults = forOp.getNumResults();
     SmallVector<bool, 4> keepMask;
-    keepMask.reserve(yieldOp.getNumOperands());
+    keepMask.reserve(numResults);
     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
         newResultValues;
-    newBlockTransferArgs.reserve(1 + forOp.getInitArgs().size());
+    newBlockTransferArgs.reserve(1 + numResults);
     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
     newIterArgs.reserve(forOp.getInitArgs().size());
-    newYieldValues.reserve(yieldOp.getNumOperands());
-    newResultValues.reserve(forOp.getNumResults());
+    newYieldValues.reserve(numResults);
+    newResultValues.reserve(numResults);
     for (auto it : llvm::zip(forOp.getInitArgs(),       // iter from outside
                              forOp.getRegionIterArgs(), // iter inside region
                              forOp.getResults(),        // op results
-                             yieldOp.getOperands()      // iter yield
+                             forOp.getYieldedValues()   // iter yield
                              )) {
       // Forwarded is `true` when:
       // 1) The region `iter` argument is yielded.
@@ -946,12 +945,10 @@ struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
       return failure();
     // If the loop is empty, iterates at least once, and only returns values
     // defined outside of the loop, remove it and replace it with yield values.
-    auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
-    auto yieldOperands = yieldOp.getOperands();
-    if (llvm::any_of(yieldOperands,
+    if (llvm::any_of(op.getYieldedValues(),
                      [&](Value v) { return !op.isDefinedOutsideOfLoop(v); }))
       return failure();
-    rewriter.replaceOp(op, yieldOperands);
+    rewriter.replaceOp(op, op.getYieldedValues());
     return success();
   }
 };
@@ -1224,6 +1221,10 @@ std::optional<APInt> ForOp::getConstantStep() {
   return {};
 }
 
+ValueRange ForOp::getYieldedValues() {
+  return cast<scf::YieldOp>(getBody()->getTerminator()).getResults();
+}
+
 Speculation::Speculatability ForOp::getSpeculatability() {
   // `scf.for (I = Start; I < End; I += 1)` terminates for all values of Start
   // and End.
@@ -3205,6 +3206,8 @@ YieldOp WhileOp::getYieldOp() {
   return cast<YieldOp>(getAfterBody()->getTerminator());
 }
 
+ValueRange WhileOp::getYieldedValues() { return getYieldOp().getResults(); }
+
 Block::BlockArgListType WhileOp::getBeforeArguments() {
   return getBeforeBody()->getArguments();
 }
diff --git a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
index bcbc693a9742ccc..22c893457c6f59e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -489,9 +489,8 @@ struct ForOpInterface
     auto forOp = cast<scf::ForOp>(op);
     OpOperand &forOperand = forOp.getOpOperandForResult(opResult);
     auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-    auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
     bool equivalentYield = state.areEquivalentBufferizedValues(
-        bbArg, yieldOp->getOperand(opResult.getResultNumber()));
+        bbArg, forOp.getYieldedValues()[opResult.getResultNumber()]);
     return equivalentYield ? BufferRelation::Equivalent
                            : BufferRelation::Unknown;
   }
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
index 0cd19fbefa8ef98..43e79d309c66784 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopCanonicalization.cpp
@@ -36,10 +36,9 @@ using namespace mlir::scf;
 /// type of the corresponding basic block argument of the loop.
 /// Note: This function handles only simple cases. Expand as needed.
 static bool isShapePreserving(ForOp forOp, int64_t arg) {
-  auto yieldOp = cast<YieldOp>(forOp.getBody()->getTerminator());
-  assert(arg < static_cast<int64_t>(yieldOp.getResults().size()) &&
+  assert(arg < static_cast<int64_t>(forOp.getNumResults()) &&
          "arg is out of bounds");
-  Value value = yieldOp.getResults()[arg];
+  Value value = forOp.getYieldedValues()[arg];
   while (value) {
     if (value == forOp.getRegionIterArgs()[arg])
       return true;
diff --git a/mlir/lib/Interfaces/LoopLikeInterface.cpp b/mlir/lib/Interfaces/LoopLikeInterface.cpp
index 781a21bb3ecd3a0..15a816f4e448839 100644
--- a/mlir/lib/Interfaces/LoopLikeInterface.cpp
+++ b/mlir/lib/Interfaces/LoopLikeInterface.cpp
@@ -52,3 +52,45 @@ bool LoopLikeOpInterface::blockIsInLoop(Block *block) {
   }
   return false;
 }
+
+LogicalResult detail::verifyLoopLikeOpInterface(Operation *op) {
+  // Note: These invariants are also verified by the RegionBranchOpInterface,
+  // but the LoopLikeOpInterface provides better error messages.
+  auto loopLikeOp = cast<LoopLikeOpInterface>(op);
+
+  // Verify number of inits/iter_args/yielded values.
+  if (loopLikeOp.getInits().size() != loopLikeOp.getRegionIterArgs().size())
+    return op->emitOpError("different number of inits and region iter_args: ")
+           << loopLikeOp.getInits().size()
+           << " != " << loopLikeOp.getRegionIterArgs().size();
+  if (loopLikeOp.getRegionIterArgs().size() !=
+      loopLikeOp.getYieldedValues().size())
+    return op->emitOpError(
+               "different number of region iter_args and yielded values: ")
+           << loopLikeOp.getRegionIterArgs().size()
+           << " != " << loopLikeOp.getYieldedValues().size();
+
+  // Verify types of inits/iter_args/yielded values. Each mismatch must return
+  // failure: an error that is only emitted (not returned) is reported, but the
+  // op would still pass verification.
+  int64_t i = 0;
+  for (const auto it :
+       llvm::zip_equal(loopLikeOp.getInits(), loopLikeOp.getRegionIterArgs(),
+                       loopLikeOp.getYieldedValues())) {
+    if (std::get<0>(it).getType() != std::get<1>(it).getType())
+      return op->emitOpError(std::to_string(i))
+             << "-th init and " << i
+             << "-th region iter_arg have different type: "
+             << std::get<0>(it).getType() << " != "
+             << std::get<1>(it).getType();
+    if (std::get<1>(it).getType() != std::get<2>(it).getType())
+      return op->emitOpError(std::to_string(i))
+             << "-th region iter_arg and " << i
+             << "-th yielded value have different type: "
+             << std::get<1>(it).getType() << " != "
+             << std::get<2>(it).getType();
+    ++i;
+  }
+
+  return success();
+}
diff --git a/mlir/test/Dialect/SCF/invalid.mlir b/mlir/test/Dialect/SCF/invalid.mlir
index f6044ad10829227..1b2c3f563195c52 100644
--- a/mlir/test/Dialect/SCF/invalid.mlir
+++ b/mlir/test/Dialect/SCF/invalid.mlir
@@ -96,6 +96,32 @@ func.func @not_enough_loop_results(%arg0: index, %init: f32) {
 
 // -----
 
+func.func @too_many_iter_args(%arg0: index, %init: f32) {
+  // expected-error @below{{different number of inits and region iter_args: 1 != 2}}
+  %x = "scf.for"(%arg0, %arg0, %arg0, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32, %iter2: f32):
+      scf.yield %iter, %iter : f32, f32
+    }
+  ) : (index, index, index, f32) -> (f32)
+  return
+}
+
+// -----
+
+func.func @too_few_yielded_values(%arg0: index, %init: f32) {
+  // expected-error @below{{different number of region iter_args and yielded values: 2 != 1}}
+  %x, %x2 = "scf.for"(%arg0, %arg0, %arg0, %init, %init) (
+    {
+    ^bb0(%i0 : index, %iter: f32, %iter2: f32):
+      scf.yield %iter : f32
+    }
+  ) : (index, index, index, f32, f32) -> (f32, f32)
+  return
+}
+
+// -----
+
 func.func @loop_if_not_i1(%arg0: index) {
   // expected-error at +1 {{operand #0 must be 1-bit signless integer}}
   "scf.if"(%arg0) ({}, {}) : (index) -> ()
@@ -422,7 +448,7 @@ func.func @std_for_operands_mismatch_3(%arg0 : index, %arg1 : index, %arg2 : ind
 func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : index) {
   %s0 = arith.constant 0.0 : f32
   %t0 = arith.constant 1.0 : f32
-  // expected-error @+1 {{along control flow edge from Region #0 to Region #0: source type #1 'i32' should match input type #1 'f32'}}
+  // expected-error @below {{1-th region iter_arg and 1-th yielded value have different type: 'f32' != 'i32'}}
   %result1:2 = scf.for %i0 = %arg0 to %arg1 step %arg2
                     iter_args(%si = %s0, %ti = %t0) -> (f32, f32) {
     %sn = arith.addf %si, %si : f32
@@ -432,7 +458,6 @@ func.func @std_for_operands_mismatch_4(%arg0 : index, %arg1 : index, %arg2 : ind
   return
 }
 
-
 // -----
 
 func.func @parallel_invalid_yield(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index 1d40615305c02c4..565d07669792f1b 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -50,9 +50,8 @@ struct TestSCFForUtilsPass
         auto newInitValues = forOp.getInitArgs();
         if (newInitValues.empty())
           return;
-        auto yieldOp = cast<scf::YieldOp>(forOp.getBody()->getTerminator());
-        SmallVector<Value> oldYieldValues(yieldOp.getResults().begin(),
-                                          yieldOp.getResults().end());
+        SmallVector<Value> oldYieldValues =
+            llvm::to_vector(forOp.getYieldedValues());
         NewYieldValuesFn fn = [&](OpBuilder &b, Location loc,
                                   ArrayRef<BlockArgument> newBBArgs) {
           SmallVector<Value> newYieldValues;



More information about the Mlir-commits mailing list