[Mlir-commits] [mlir] 567fd52 - [mlir][SCF] Add utility method to add new yield values to a loop.

Mahesh Ravishankar llvmlistbot at llvm.org
Tue May 10 11:44:31 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-05-10T18:44:11Z
New Revision: 567fd523bf538523f58779e5af9d20c3e48838a2

URL: https://github.com/llvm/llvm-project/commit/567fd523bf538523f58779e5af9d20c3e48838a2
DIFF: https://github.com/llvm/llvm-project/commit/567fd523bf538523f58779e5af9d20c3e48838a2.diff

LOG: [mlir][SCF] Add utility method to add new yield values to a loop.

The current implementation of `cloneWithNewYields` has a few issues:
- It clones the body of the original loop to create a new
  loop. This is very expensive.
- It performs `erase` operations, which are incompatible with calling
  this method from within a pattern rewrite. All erases need to
  go through the `PatternRewriter`.
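The toy C++ model below illustrates the cost difference. It is not MLIR code: a loop body is represented as a `std::list` of op names. The hypothetical `cloneBody` copies every element. The hypothetical `moveBody` uses `std::list::splice` to relink the existing nodes without copying any of them. The new utility relies on the same idea through `Block::getOperations().splice(...)`, but this sketch does not model the exact cost of MLIR's intrusive op list.

```cpp
#include <cassert>
#include <list>
#include <string>

// Toy stand-in for a loop body: just a list of operation names.
using ToyBody = std::list<std::string>;

// Old approach (modeled): build the new body by copying every op.
// The source body is left untouched, so each op now exists twice.
ToyBody cloneBody(const ToyBody &src) {
  ToyBody dst;
  for (const std::string &op : src)
    dst.push_back(op); // one copy per operation
  return dst;
}

// New approach (modeled): relink all nodes of `src` onto the end of `dst`.
// No element is copied, and `src` is left empty afterwards.
void moveBody(ToyBody &dst, ToyBody &src) {
  dst.splice(dst.end(), src);
}
```

After `moveBody`, pointers to the moved elements stay valid and now refer to elements of `dst`. This mirrors how the moved operations keep their identity in the new loop.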

To address these issues, a new utility method `replaceLoopWithNewYields`
is added, which:
- moves the operations from the original loop into the new loop.
- replaces all uses of the original loop with the corresponding
  results of the new loop.
- uses a callback to allow the caller to generate the new yield values.
- modifies the original loop to just yield the basic block
  arguments corresponding to the iter_args of the loop. This
  makes it a no-op loop. The loop itself is dead (since all its uses
  are replaced) but is not removed; the caller is expected to erase
  it. Consequently, this method can be called from within the
  `matchAndRewrite` method of a rewrite pattern.
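The no-op claim can be checked with a toy evaluator. The sketch below is a hypothetical model of `scf.for` semantics over `double` iter_args. `ToyBodyFn`, `runToyFor` and `noOpBody` are illustrative names, not MLIR APIs. A body that yields its block arguments unchanged leaves the init values untouched. The dead loop's results would therefore equal its init values even if something still used them.

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical model of a loop body: maps the current iter_args values
// (the block arguments) to the values passed to the yield.
using ToyBodyFn =
    std::function<std::vector<double>(const std::vector<double> &)>;

// Evaluates `for i = lb to ub step step iter_args(init)` in the toy model.
// Assumes step > 0.
std::vector<double> runToyFor(long lb, long ub, long step,
                              std::vector<double> iterArgs,
                              const ToyBodyFn &body) {
  for (long i = lb; i < ub; i += step)
    iterArgs = body(iterArgs);
  return iterArgs;
}

// The body left behind in the original loop: yield the block arguments as-is.
std::vector<double> noOpBody(const std::vector<double> &bbArgs) {
  return bbArgs;
}
```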

`cloneWithNewYields` could be replaced with
`replaceLoopWithNewYields`, but doing so seems to trigger a failure
during walks, potentially because the operations are moved. That is
left as a TODO.

Differential Revision: https://reviews.llvm.org/D125147

Added: 
    mlir/test/Transforms/scf-replace-with-new-yields.mlir

Modified: 
    mlir/include/mlir/Dialect/SCF/Utils/Utils.h
    mlir/lib/Dialect/SCF/Utils/Utils.cpp
    mlir/test/Transforms/scf-loop-utils.mlir
    mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
index bb5a4848dad1e..70fc0842d8fc6 100644
--- a/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/SCF/Utils/Utils.h
@@ -61,6 +61,28 @@ scf::ForOp cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
                               ValueRange newYieldedValues,
                               bool replaceLoopResults = true);
 
+/// Replace the `loop` with `newIterOperands` added as new initialization
+/// values. `newYieldValuesFn` is a callback that can be used to specify
+/// the additional values to be yielded by the loop. The number of
+/// values returned by the callback should match the number of new
+/// initialization values. This function
+/// - Moves (i.e. doesn't clone) operations from the `loop` to the newly created
+///   loop
+/// - Replaces the uses of `loop` with the new loop.
+/// - `loop` isn't erased, but is left in a "no-op" state where the body of the
+///   loop just yields the basic block arguments that correspond to the
+///   initialization values of a loop. The loop is dead after this method.
+/// - All uses of the `newIterOperands` within the generated new loop
+///   are replaced with the corresponding `BlockArgument` in the loop body.
+/// TODO: This method could be used instead of `cloneWithNewYields`. Making
+/// this change, though, hits assertions in the walk mechanism that are
+/// unrelated to this method itself.
+using NewYieldValueFn = std::function<SmallVector<Value>(
+    OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBBArgs)>;
+scf::ForOp replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                                    ValueRange newIterOperands,
+                                    NewYieldValueFn newYieldValuesFn);
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
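The contract documented above requires one new yield value per new initialization value. It can be sketched in plain C++ with these hypothetical stand-ins:
- `ToyValue` for `mlir::Value`
- `ToyNewYieldFn` for `NewYieldValueFn`
- `appendNewYields` for the step that appends the callback's results to the `scf.yield`

The real implementation checks the count with an `assert`. This sketch throws instead, so a mismatch is easy to observe.

```cpp
#include <cassert>
#include <functional>
#include <stdexcept>
#include <vector>

using ToyValue = int; // hypothetical stand-in for mlir::Value

// Hypothetical analogue of NewYieldValueFn: receives the newly added block
// arguments and returns the extra values the loop should yield.
using ToyNewYieldFn = std::function<std::vector<ToyValue>(
    const std::vector<ToyValue> &newBBArgs)>;

// Appends the callback's values to `yielded`. A count mismatch between new
// block arguments and new yield values is rejected before anything changes.
void appendNewYields(std::vector<ToyValue> &yielded,
                     const std::vector<ToyValue> &newBBArgs,
                     const ToyNewYieldFn &fn) {
  std::vector<ToyValue> extra = fn(newBBArgs);
  if (extra.size() != newBBArgs.size())
    throw std::invalid_argument(
        "expected as many new yield values as new iter operands");
  yielded.insert(yielded.end(), extra.begin(), extra.end());
}
```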

diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 5d69a3fa7260b..a2a751a3881fc 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -91,6 +91,70 @@ scf::ForOp mlir::cloneWithNewYields(OpBuilder &b, scf::ForOp loop,
   return newLoop;
 }
 
+scf::ForOp mlir::replaceLoopWithNewYields(OpBuilder &builder, scf::ForOp loop,
+                                          ValueRange newIterOperands,
+                                          NewYieldValueFn newYieldValuesFn) {
+  // Create a new loop before the existing one, with the extra operands.
+  OpBuilder::InsertionGuard g(builder);
+  builder.setInsertionPoint(loop);
+  auto operands = llvm::to_vector(loop.getIterOperands());
+  operands.append(newIterOperands.begin(), newIterOperands.end());
+  scf::ForOp newLoop = builder.create<scf::ForOp>(
+      loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
+      operands, [](OpBuilder &, Location, Value, ValueRange) {});
+
+  Block *loopBody = loop.getBody();
+  Block *newLoopBody = newLoop.getBody();
+
+  // Move the body of the original loop to the new loop.
+  newLoopBody->getOperations().splice(newLoopBody->end(),
+                                      loopBody->getOperations());
+
+  // Generate the new yield values using the callback and append them to the
+  // scf.yield operation.
+  auto yield = cast<scf::YieldOp>(newLoopBody->getTerminator());
+  ArrayRef<BlockArgument> newBBArgs =
+      newLoopBody->getArguments().take_back(newIterOperands.size());
+  {
+    OpBuilder::InsertionGuard g(builder);
+    builder.setInsertionPoint(yield);
+    SmallVector<Value> newYieldedValues =
+        newYieldValuesFn(builder, loop.getLoc(), newBBArgs);
+    assert(newIterOperands.size() == newYieldedValues.size() &&
+           "expected as many new yield values as new iter operands");
+    yield.getResultsMutable().append(newYieldedValues);
+  }
+
+  // Remap the BlockArguments from the original loop to the new loop
+  // BlockArguments.
+  ArrayRef<BlockArgument> bbArgs = loopBody->getArguments();
+  for (auto it :
+       llvm::zip(bbArgs, newLoopBody->getArguments().take_front(bbArgs.size())))
+    std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
+
+  // Replace all uses of `newIterOperands` with the corresponding basic block
+  // arguments.
+  for (auto it : llvm::zip(newIterOperands, newBBArgs)) {
+    std::get<0>(it).replaceUsesWithIf(std::get<1>(it), [&](OpOperand &use) {
+      Operation *user = use.getOwner();
+      return newLoop->isProperAncestor(user);
+    });
+  }
+
+  // Replace all uses of the original loop with corresponding values from the
+  // new loop.
+  loop.replaceAllUsesWith(
+      newLoop.getResults().take_front(loop.getNumResults()));
+
+  // Add a fake yield to the original loop body that just returns the
+  // BlockArguments corresponding to the iter_args. This makes it a no-op loop.
+  // The loop is dead. The caller is expected to erase it.
+  builder.setInsertionPointToEnd(loopBody);
+  builder.create<scf::YieldOp>(loop->getLoc(), loop.getRegionIterArgs());
+
+  return newLoop;
+}
+
 /// Outline a region with a single block into a new FuncOp.
 /// Assumes the FuncOp result types is the type of the yielded operands of the
 /// single block. This constraint makes it easy to determine the result.
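The use-rewiring at the end of `replaceLoopWithNewYields` treats the two kinds of use differently:
- Uses of the old block arguments are redirected unconditionally.
- Uses of `newIterOperands` are redirected only when the user is strictly nested inside the new loop, because those values may still be used elsewhere.

The strict (proper) ancestor check matters. The new loop itself uses each new iter operand as an init value, and that use must not be rewritten to the loop's own block argument.

The toy model below mimics this with a parent chain. `ToyOp`, `ToyUse`, `toyIsProperAncestor` and `toyReplaceUsesWithIf` are hypothetical simplifications, not the MLIR APIs.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <vector>

// Hypothetical toy IR: each op knows its parent op (nullptr at top level).
struct ToyOp {
  std::string name;
  const ToyOp *parent = nullptr;
};

// One operand slot: the op that owns it and the value name it reads.
struct ToyUse {
  const ToyOp *owner;
  std::string value;
};

// True if `ancestor` strictly encloses `op`; the op itself does not count.
bool toyIsProperAncestor(const ToyOp *ancestor, const ToyOp *op) {
  for (const ToyOp *p = op->parent; p; p = p->parent)
    if (p == ancestor)
      return true;
  return false;
}

// Rewrites uses of `from` to `to`, but only where `shouldReplace` approves.
void toyReplaceUsesWithIf(
    std::vector<ToyUse> &uses, const std::string &from, const std::string &to,
    const std::function<bool(const ToyUse &)> &shouldReplace) {
  for (ToyUse &use : uses)
    if (use.value == from && shouldReplace(use))
      use.value = to;
}
```

In the usage below, only the use inside the loop body is redirected. The use after the loop and the loop's own init operand keep the original value.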

diff  --git a/mlir/test/Transforms/scf-loop-utils.mlir b/mlir/test/Transforms/scf-loop-utils.mlir
index 32e42b9da9970..3e03d3a9044d7 100644
--- a/mlir/test/Transforms/scf-loop-utils.mlir
+++ b/mlir/test/Transforms/scf-loop-utils.mlir
@@ -1,4 +1,4 @@
-// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils -mlir-disable-threading %s | FileCheck %s
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-clone-with-new-yields -mlir-disable-threading %s | FileCheck %s
 
 // CHECK-LABEL: @hoist
 //  CHECK-SAME: %[[lb:[a-zA-Z0-9]*]]: index,

diff  --git a/mlir/test/Transforms/scf-replace-with-new-yields.mlir b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
new file mode 100644
index 0000000000000..802f86e6806b3
--- /dev/null
+++ b/mlir/test/Transforms/scf-replace-with-new-yields.mlir
@@ -0,0 +1,21 @@
+
+// RUN: mlir-opt -allow-unregistered-dialect -test-scf-for-utils=test-replace-with-new-yields -mlir-disable-threading %s | FileCheck %s
+
+func.func @doubleup(%lb: index, %ub: index, %step: index, %extra_arg: f32) -> f32 {
+  %0 = scf.for %i = %lb to %ub step %step iter_args(%iter = %extra_arg) -> (f32) {
+    %1 = arith.addf %iter, %iter : f32
+    scf.yield %1: f32
+  }
+  return %0: f32
+}
+// CHECK-LABEL: func @doubleup
+//  CHECK-SAME:   %[[ARG:[a-zA-Z0-9]+]]: f32
+//       CHECK:   %[[NEWLOOP:.+]]:2 = scf.for
+//  CHECK-SAME:       iter_args(%[[INIT1:.+]] = %[[ARG]], %[[INIT2:.+]] = %[[ARG]]
+//       CHECK:     %[[DOUBLE:.+]] = arith.addf %[[INIT1]], %[[INIT1]]
+//       CHECK:     %[[DOUBLE2:.+]] = arith.addf %[[DOUBLE]], %[[DOUBLE]]
+//       CHECK:     scf.yield %[[DOUBLE]], %[[DOUBLE2]]
+//       CHECK:   %[[OLDLOOP:.+]] = scf.for
+//  CHECK-SAME:       iter_args(%[[INIT:.+]] = %[[ARG]])
+//       CHECK:     scf.yield %[[INIT]]
+//       CHECK:   return %[[NEWLOOP]]#0
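The CHECK lines above can be read as arithmetic. For each existing yield operand, the test callback appends an `arith.addf` of that operand with itself. The plain C++ simulation below is hand-written from the CHECK lines, not generated by MLIR. It uses a trip count of 3 and `%extra_arg = 1.0`, both chosen only for illustration.

It shows three things:
- Result #0 of the new loop matches the original loop.
- Result #1 is twice that value.
- The leftover no-op loop returns its initial value.

The new second iter_arg (`%INIT2`) is carried but never read, because the callback doubles the existing yield operand rather than the new block argument.

```cpp
#include <cassert>
#include <utility>

// Original @doubleup loop: %1 = %iter + %iter, yielded each iteration.
float runOriginal(int tripCount, float init) {
  float iter = init;
  for (int i = 0; i < tripCount; ++i)
    iter = iter + iter;
  return iter;
}

// New loop per the CHECK lines: %DOUBLE = %INIT1 + %INIT1 and
// %DOUBLE2 = %DOUBLE + %DOUBLE, yielding (%DOUBLE, %DOUBLE2).
// %INIT2 is carried across iterations but never read by the body.
std::pair<float, float> runNewLoop(int tripCount, float init) {
  float init1 = init, init2 = init;
  for (int i = 0; i < tripCount; ++i) {
    float dbl = init1 + init1;
    float dbl2 = dbl + dbl;
    init1 = dbl;
    init2 = dbl2;
  }
  return {init1, init2};
}

// Leftover original loop: its body only yields its block argument.
float runNoOpLoop(int tripCount, float init) {
  float iter = init;
  for (int i = 0; i < tripCount; ++i) {
    // The body is just `scf.yield %INIT`, so the carried value is unchanged.
  }
  return iter;
}
```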

diff  --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index f354a309cd6e6..858b0b1f0d42e 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -32,29 +32,67 @@ struct TestSCFForUtilsPass
   StringRef getArgument() const final { return "test-scf-for-utils"; }
   StringRef getDescription() const final { return "test scf.for utils"; }
   explicit TestSCFForUtilsPass() = default;
+  TestSCFForUtilsPass(const TestSCFForUtilsPass &pass) : PassWrapper(pass) {}
+
+  Option<bool> testCloneWithNewYields{
+      *this, "test-clone-with-new-yields",
+      llvm::cl::desc(
+          "Test cloning of a loop while returning additional yield values"),
+      llvm::cl::init(false)};
+
+  Option<bool> testReplaceWithNewYields{
+      *this, "test-replace-with-new-yields",
+      llvm::cl::desc("Test replacing a loop with a new loop that returns new "
+                     "additional yield values"),
+      llvm::cl::init(false)};
 
   void runOnOperation() override {
     func::FuncOp func = getOperation();
     SmallVector<scf::ForOp, 4> toErase;
 
-    func.walk([&](Operation *fakeRead) {
-      if (fakeRead->getName().getStringRef() != "fake_read")
-        return;
-      auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
-      auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
-      auto loop = fakeRead->getParentOfType<scf::ForOp>();
-
-      OpBuilder b(loop);
-      loop.moveOutOfLoop(fakeRead);
-      fakeWrite->moveAfter(loop);
-      auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
-                                        fakeCompute->getResult(0));
-      fakeCompute->getResult(0).replaceAllUsesWith(
-          newLoop.getResults().take_back()[0]);
-      toErase.push_back(loop);
-    });
-    for (auto loop : llvm::reverse(toErase))
-      loop.erase();
+    if (testCloneWithNewYields) {
+      func.walk([&](Operation *fakeRead) {
+        if (fakeRead->getName().getStringRef() != "fake_read")
+          return;
+        auto *fakeCompute = fakeRead->getResult(0).use_begin()->getOwner();
+        auto *fakeWrite = fakeCompute->getResult(0).use_begin()->getOwner();
+        auto loop = fakeRead->getParentOfType<scf::ForOp>();
+
+        OpBuilder b(loop);
+        loop.moveOutOfLoop(fakeRead);
+        fakeWrite->moveAfter(loop);
+        auto newLoop = cloneWithNewYields(b, loop, fakeRead->getResult(0),
+                                          fakeCompute->getResult(0));
+        fakeCompute->getResult(0).replaceAllUsesWith(
+            newLoop.getResults().take_back()[0]);
+        toErase.push_back(loop);
+      });
+      for (auto loop : llvm::reverse(toErase))
+        loop.erase();
+    }
+
+    if (testReplaceWithNewYields) {
+      func.walk([&](scf::ForOp forOp) {
+        if (forOp.getNumResults() == 0)
+          return;
+        auto newInitValues = forOp.getInitArgs();
+        if (newInitValues.empty())
+          return;
+        NewYieldValueFn fn = [&](OpBuilder &b, Location loc,
+                                 ArrayRef<BlockArgument> newBBArgs) {
+          Block *block = newBBArgs.front().getOwner();
+          SmallVector<Value> newYieldValues;
+          for (auto yieldVal :
+               cast<scf::YieldOp>(block->getTerminator()).getResults()) {
+            newYieldValues.push_back(
+                b.create<arith::AddFOp>(loc, yieldVal, yieldVal));
+          }
+          return newYieldValues;
+        };
+        OpBuilder b(forOp);
+        replaceLoopWithNewYields(b, forOp, newInitValues, fn);
+      });
+    }
   }
 };
 
@@ -88,7 +126,8 @@ static const StringLiteral kTestPipeliningLoopMarker =
     "__test_pipelining_loop__";
 static const StringLiteral kTestPipeliningStageMarker =
     "__test_pipelining_stage__";
-/// Marker to express the order in which operations should be after pipelining.
+/// Marker to express the order in which operations should be after
+/// pipelining.
 static const StringLiteral kTestPipeliningOpOrderMarker =
     "__test_pipelining_op_order__";
 


        

