[Mlir-commits] [mlir] 0d5cb90 - [mlir][scf] Simplify the logic for `replaceLoopWithNewYields` for perfectly nested loops.

Mahesh Ravishankar llvmlistbot at llvm.org
Thu Sep 29 09:52:22 PDT 2022


Author: Mahesh Ravishankar
Date: 2022-09-29T16:52:02Z
New Revision: 0d5cb90f6c44730a74f49feb6f5b624b3414e459

URL: https://github.com/llvm/llvm-project/commit/0d5cb90f6c44730a74f49feb6f5b624b3414e459
DIFF: https://github.com/llvm/llvm-project/commit/0d5cb90f6c44730a74f49feb6f5b624b3414e459.diff

LOG: [mlir][scf] Simplify the logic for `replaceLoopWithNewYields` for perfectly nested loops.

Based on discussion in https://reviews.llvm.org/D134411, instead of
modifying the innermost loop first and then modifying the outer loops
from the inside out, this patch restructures the logic to start the
modification from the outermost loop.

Differential Revision: https://reviews.llvm.org/D134832

Added: 
    

Modified: 
    mlir/lib/Dialect/SCF/Utils/Utils.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index c6e416bda6df9..e99510e66d3f1 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -111,56 +111,66 @@ SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
     bool replaceIterOperandsUsesInLoop) {
   if (loopNest.empty())
     return {};
-  SmallVector<scf::ForOp> newLoopNest(loopNest.size());
-
-  newLoopNest.back() = replaceLoopWithNewYields(
-      builder, loopNest.back(), newIterOperands, newYieldValueFn);
-
-  for (unsigned loopDepth :
-       llvm::reverse(llvm::seq<unsigned>(0, loopNest.size() - 1))) {
-    NewYieldValueFn fn = [&](OpBuilder &innerBuilder, Location loc,
-                             ArrayRef<BlockArgument> innerNewBBArgs) {
-      SmallVector<Value> newYields(
-          newLoopNest[loopDepth + 1]->getResults().take_back(
-              newIterOperands.size()));
-      return newYields;
-    };
-    newLoopNest[loopDepth] =
-        replaceLoopWithNewYields(builder, loopNest[loopDepth], newIterOperands,
-                                 fn, replaceIterOperandsUsesInLoop);
-    if (!replaceIterOperandsUsesInLoop) {
-      /// The yield is expected to producer the following structure
-      /// ```
-      /// %0 = scf.for ... iter_args(%arg0 = %init) {
-      ///   %1 = scf.for ... iter_args(%arg1 = %arg0) {
-      ///     scf.yield %yield
-      ///   }
-      /// }
-      /// ```
-      ///
-      /// since the yield is propagated from inside out, after the inner
-      /// loop is processed the IR is in this form
-      ///
-      /// ```
-      /// scf.for ... iter_args {
-      ///   %1 = scf.for ... iter_args(%arg1 = %init) {
-      ///     scf.yield %yield
-      ///   }
-      /// ```
-      ///
-      /// If `replaceIterOperandUsesInLoops` is true, there is nothing to do.
-      /// `%init` will be replaced with `%arg0` when it is created for the
-      /// outer loop. But without that this has to be done explicitly.
-      unsigned subLen = newIterOperands.size();
-      unsigned subStart =
-          newLoopNest[loopDepth + 1].getNumIterOperands() - subLen;
-      auto resetOperands =
-          newLoopNest[loopDepth + 1].getInitArgsMutable().slice(subStart,
-                                                                subLen);
-      resetOperands.assign(
-          newLoopNest[loopDepth].getRegionIterArgs().take_back(subLen));
-    }
+  // This method is recursive (to make it more readable). Adding an
+  // assertion here to limit the recursion. (See
+  // https://discourse.llvm.org/t/rfc-update-to-mlir-developer-policy-on-recursion/62235)
+  assert(loopNest.size() <= 6 &&
+         "exceeded recursion limit when yielding value from loop nest");
+
+  // To yield a value from a perfectly nested loop nest, the following
+  // pattern needs to be created, i.e. starting with
+  //
+  // ```mlir
+  //  scf.for .. {
+  //    scf.for .. {
+  //      scf.for .. {
+  //        %value = ...
+  //      }
+  //    }
+  //  }
+  // ```
+  //
+  // needs to be modified to
+  //
+  // ```mlir
+  // %0 = scf.for .. iter_args(%arg0 = %init) {
+  //   %1 = scf.for .. iter_args(%arg1 = %arg0) {
+  //     %2 = scf.for .. iter_args(%arg2 = %arg1) {
+  //       %value = ...
+  //       scf.yield %value
+  //     }
+  //     scf.yield %2
+  //   }
+  //   scf.yield %1
+  // }
+  // ```
+  //
+  // The inner most loop is handled using the `replaceLoopWithNewYields`
+  // that works on a single loop.
+  if (loopNest.size() == 1) {
+    auto innerMostLoop = replaceLoopWithNewYields(
+        builder, loopNest.back(), newIterOperands, newYieldValueFn,
+        replaceIterOperandsUsesInLoop);
+    return {innerMostLoop};
   }
+  // The outer loops are modified by calling this method recursively
+  // - The return value of the inner loop is the value yielded by this loop.
+  // - The region iter args of this loop are the init_args for the inner loop.
+  SmallVector<scf::ForOp> newLoopNest;
+  NewYieldValueFn fn =
+      [&](OpBuilder &innerBuilder, Location loc,
+          ArrayRef<BlockArgument> innerNewBBArgs) -> SmallVector<Value> {
+    newLoopNest = replaceLoopNestWithNewYields(builder, loopNest.drop_front(),
+                                               innerNewBBArgs, newYieldValueFn,
+                                               replaceIterOperandsUsesInLoop);
+    return llvm::to_vector(llvm::map_range(
+        newLoopNest.front().getResults().take_back(innerNewBBArgs.size()),
+        [](OpResult r) -> Value { return r; }));
+  };
+  scf::ForOp outerMostLoop =
+      replaceLoopWithNewYields(builder, loopNest.front(), newIterOperands, fn,
+                               replaceIterOperandsUsesInLoop);
+  newLoopNest.insert(newLoopNest.begin(), outerMostLoop);
   return newLoopNest;
 }
 


        


More information about the Mlir-commits mailing list