[Mlir-commits] [mlir] cbcb12f - [MLIR] Handle in-place folding properly in greedy pattern rewrite driver

Uday Bondhugula llvmlistbot at llvm.org
Sat Apr 11 07:28:05 PDT 2020


Author: Uday Bondhugula
Date: 2020-04-11T19:57:29+05:30
New Revision: cbcb12fd44dfdb51bbf4489d213d96f17be3091f

URL: https://github.com/llvm/llvm-project/commit/cbcb12fd44dfdb51bbf4489d213d96f17be3091f
DIFF: https://github.com/llvm/llvm-project/commit/cbcb12fd44dfdb51bbf4489d213d96f17be3091f.diff

LOG: [MLIR] Handle in-place folding properly in greedy pattern rewrite driver

OperationFolder::tryToFold performs both true folding and, in a few
instances, in-place updates through op rewrites. In the latter case, we
should still be applying the supplied pattern rewrites in the same
iteration; however, this wasn't the case since tryToFold returned
success() for both true folding and in-place updates, and the patterns
for the in-place updated ops were being applied only in the next
iteration of the driver's outer loop. This fix makes the driver converge
faster.

Differential Revision: https://reviews.llvm.org/D77485

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 0bab87c5e4e3..d2ba43339ce3 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -56,11 +56,12 @@ class OperationFolder {
   /// folded results, and returns success. `preReplaceAction` is invoked on `op`
   /// before it is replaced. 'processGeneratedConstants' is invoked for any new
   /// operations generated when folding. If the op was completely folded it is
-  /// erased.
+  /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
   LogicalResult
   tryToFold(Operation *op,
             function_ref<void(Operation *)> processGeneratedConstants = nullptr,
-            function_ref<void(Operation *)> preReplaceAction = nullptr);
+            function_ref<void(Operation *)> preReplaceAction = nullptr,
+            bool *inPlaceUpdate = nullptr);
 
   /// Notifies that the given constant `op` should be remove from this
   /// OperationFolder's internal bookkeeping.

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index f2099bca75ea..9e67c2b6b348 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -74,7 +74,10 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
 
 LogicalResult OperationFolder::tryToFold(
     Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
-    function_ref<void(Operation *)> preReplaceAction) {
+    function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
+  if (inPlaceUpdate)
+    *inPlaceUpdate = false;
+
   // If this is a unique'd constant, return failure as we know that it has
   // already been folded.
   if (referencedDialects.count(op))
@@ -87,8 +90,11 @@ LogicalResult OperationFolder::tryToFold(
     return failure();
 
   // Check to see if the operation was just updated in place.
-  if (results.empty())
+  if (results.empty()) {
+    if (inPlaceUpdate)
+      *inPlaceUpdate = true;
     return success();
+  }
 
   // Constant folding succeeded. We will start replacing this op's uses and
   // erase this op. Invoke the callback provided by the caller to perform any

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 80ad143ce0d3..53c8e9fbd1c2 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -104,7 +104,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
   // be re-added to the worklist. This function should be called when an
   // operation is modified or removed, as it may trigger further
   // simplifications.
-  template <typename Operands> void addToWorklist(Operands &&operands) {
+  template <typename Operands>
+  void addToWorklist(Operands &&operands) {
     for (Value operand : operands) {
       // If the use count of this operand is now < 2, we re-add the defining
       // operation to the worklist.
@@ -133,7 +134,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter {
 };
 } // end anonymous namespace
 
-/// Perform the rewrites while folding and erasing any dead ops.
+/// Performs the rewrites while folding and erasing any dead ops. Returns true
+/// if the rewrite converges in `maxIterations`.
 bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
                                           int maxIterations) {
   // Add the given operation to the worklist.
@@ -183,9 +185,12 @@ bool GreedyPatternRewriteDriver::simplify(MutableArrayRef<Region> regions,
       };
 
       // Try to fold this op.
-      if (succeeded(folder.tryToFold(op, collectOps, preReplaceAction))) {
+      bool inPlaceUpdate;
+      if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
+                                      &inPlaceUpdate)))) {
         changed = true;
-        continue;
+        if (!inPlaceUpdate)
+          continue;
       }
 
       // Make sure that any new operations are inserted at this point.


        


More information about the Mlir-commits mailing list