[Mlir-commits] [mlir] 9bdfa8d - [mlir][IR] Use Listener for IR callbacks in OperationFolder

Matthias Springer llvmlistbot at llvm.org
Thu Feb 23 00:02:53 PST 2023


Author: Matthias Springer
Date: 2023-02-23T08:56:43+01:00
New Revision: 9bdfa8df0db21845b8e1d8fc0fc8b70dfe25f45d

URL: https://github.com/llvm/llvm-project/commit/9bdfa8df0db21845b8e1d8fc0fc8b70dfe25f45d
DIFF: https://github.com/llvm/llvm-project/commit/9bdfa8df0db21845b8e1d8fc0fc8b70dfe25f45d.diff

LOG: [mlir][IR] Use Listener for IR callbacks in OperationFolder

Remove the IR modification callbacks from `OperationFolder`. Instead, an optional `RewriterBase::Listener` can be specified.
* `processGeneratedConstants` => `notifyOperationInserted`
* `preReplaceAction` => `notifyOperationReplaced` (the folded op is then erased and reported via `notifyOperationRemoved`)

This simplifies the GreedyPatternRewriteDriver because it no longer needs special handling for IR modifications caused by op folding.

A folded operation is now enqueued on the GreedyPatternRewriteDriver's worklist if it was modified in place (the folder reports this via `notifyOperationModified`), because new patterns may apply after folding.

Also fixes a bug in `TestOpInPlaceFold::fold`: previously the folder could be applied over and over, because it never returned a "null" OpFoldResult, even when it did not modify the IR. (This is similar to a pattern that returns `success` without modifying the IR, which can trigger an infinite loop in the GreedyPatternRewriteDriver.)

Differential Revision: https://reviews.llvm.org/D144463

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/FoldUtils.h
    mlir/lib/Transforms/Utils/FoldUtils.cpp
    mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
    mlir/test/lib/Dialect/Test/TestDialect.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp
    mlir/test/lib/Transforms/TestConstantFold.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/FoldUtils.h b/mlir/include/mlir/Transforms/FoldUtils.h
index 2e5a80cbce35d..ff1083724e59c 100644
--- a/mlir/include/mlir/Transforms/FoldUtils.h
+++ b/mlir/include/mlir/Transforms/FoldUtils.h
@@ -17,6 +17,7 @@
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/Dialect.h"
 #include "mlir/IR/DialectInterface.h"
+#include "mlir/IR/PatternMatch.h"
 #include "mlir/Interfaces/FoldInterfaces.h"
 
 namespace mlir {
@@ -31,19 +32,14 @@ class Value;
 /// generated along the way.
 class OperationFolder {
 public:
-  OperationFolder(MLIRContext *ctx) : interfaces(ctx) {}
+  OperationFolder(MLIRContext *ctx, RewriterBase::Listener *listener = nullptr)
+      : interfaces(ctx), listener(listener) {}
 
   /// Tries to perform folding on the given `op`, including unifying
   /// deduplicated constants. If successful, replaces `op`'s uses with
-  /// folded results, and returns success. `preReplaceAction` is invoked on `op`
-  /// before it is replaced. 'processGeneratedConstants' is invoked for any new
-  /// operations generated when folding. If the op was completely folded it is
+  /// folded results, and returns success. If the op was completely folded it is
   /// erased. If it is just updated in place, `inPlaceUpdate` is set to true.
-  LogicalResult
-  tryToFold(Operation *op,
-            function_ref<void(Operation *)> processGeneratedConstants = nullptr,
-            function_ref<void(Operation *)> preReplaceAction = nullptr,
-            bool *inPlaceUpdate = nullptr);
+  LogicalResult tryToFold(Operation *op, bool *inPlaceUpdate = nullptr);
 
   /// Tries to fold a pre-existing constant operation. `constValue` represents
   /// the value of the constant, and can be optionally passed if the value is
@@ -122,23 +118,23 @@ class OperationFolder {
   using ConstantMap =
       DenseMap<std::tuple<Dialect *, Attribute, Type>, Operation *>;
 
+  /// Erase the given operation and notify the listener.
+  void eraseOp(Operation *op);
+
   /// Returns true if the given operation is an already folded constant that is
   /// owned by this folder.
   bool isFolderOwnedConstant(Operation *op) const;
 
   /// Tries to perform folding on the given `op`. If successful, populates
   /// `results` with the results of the folding.
-  LogicalResult tryToFold(
-      OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
-      function_ref<void(Operation *)> processGeneratedConstants = nullptr);
+  LogicalResult tryToFold(OpBuilder &builder, Operation *op,
+                          SmallVectorImpl<Value> &results);
 
   /// Try to process a set of fold results, generating constants as necessary.
   /// Populates `results` on success, otherwise leaves it unchanged.
-  LogicalResult
-  processFoldResults(OpBuilder &builder, Operation *op,
-                     SmallVectorImpl<Value> &results,
-                     ArrayRef<OpFoldResult> foldResults,
-                     function_ref<void(Operation *)> processGeneratedConstants);
+  LogicalResult processFoldResults(OpBuilder &builder, Operation *op,
+                                   SmallVectorImpl<Value> &results,
+                                   ArrayRef<OpFoldResult> foldResults);
 
   /// Try to get or create a new constant entry. On success this returns the
   /// constant operation, nullptr otherwise.
@@ -156,6 +152,9 @@ class OperationFolder {
 
   /// A collection of dialect folder interfaces.
   DialectInterfaceCollection<DialectFoldInterface> interfaces;
+
+  /// An optional listener that is notified of all IR changes.
+  RewriterBase::Listener *listener = nullptr;
 };
 
 } // namespace mlir

diff  --git a/mlir/lib/Transforms/Utils/FoldUtils.cpp b/mlir/lib/Transforms/Utils/FoldUtils.cpp
index fcfdbe4afab5f..22a488efb37ea 100644
--- a/mlir/lib/Transforms/Utils/FoldUtils.cpp
+++ b/mlir/lib/Transforms/Utils/FoldUtils.cpp
@@ -67,9 +67,7 @@ static Operation *materializeConstant(Dialect *dialect, OpBuilder &builder,
 // OperationFolder
 //===----------------------------------------------------------------------===//
 
-LogicalResult OperationFolder::tryToFold(
-    Operation *op, function_ref<void(Operation *)> processGeneratedConstants,
-    function_ref<void(Operation *)> preReplaceAction, bool *inPlaceUpdate) {
+LogicalResult OperationFolder::tryToFold(Operation *op, bool *inPlaceUpdate) {
   if (inPlaceUpdate)
     *inPlaceUpdate = false;
 
@@ -86,27 +84,26 @@ LogicalResult OperationFolder::tryToFold(
 
   // Try to fold the operation.
   SmallVector<Value, 8> results;
-  OpBuilder builder(op);
-  if (failed(tryToFold(builder, op, results, processGeneratedConstants)))
+  OpBuilder builder(op, listener);
+  if (failed(tryToFold(builder, op, results)))
     return failure();
 
   // Check to see if the operation was just updated in place.
   if (results.empty()) {
     if (inPlaceUpdate)
       *inPlaceUpdate = true;
+    if (listener)
+      listener->notifyOperationModified(op);
     return success();
   }
 
-  // Constant folding succeeded. We will start replacing this op's uses and
-  // erase this op. Invoke the callback provided by the caller to perform any
-  // pre-replacement action.
-  if (preReplaceAction)
-    preReplaceAction(op);
-
-  // Replace all of the result values and erase the operation.
+  // Constant folding succeeded. Replace all of the result values and erase the
+  // operation.
+  if (listener)
+    listener->notifyOperationReplaced(op, results);
   for (unsigned i = 0, e = results.size(); i != e; ++i)
     op->getResult(i).replaceAllUsesWith(results[i]);
-  op->erase();
+  eraseOp(op);
   return success();
 }
 
@@ -144,8 +141,10 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
 
   // If there is an existing constant, replace `op`.
   if (folderConstOp) {
+    if (listener)
+      listener->notifyOperationReplaced(op, folderConstOp->getResults());
     op->replaceAllUsesWith(folderConstOp);
-    op->erase();
+    eraseOp(op);
     return false;
   }
 
@@ -163,6 +162,13 @@ bool OperationFolder::insertKnownConstant(Operation *op, Attribute constValue) {
   return true;
 }
 
+void OperationFolder::eraseOp(Operation *op) {
+  notifyRemoval(op);
+  if (listener)
+    listener->notifyOperationRemoved(op);
+  op->erase();
+}
+
 /// Notifies that the given constant `op` should be remove from this
 /// OperationFolder's internal bookkeeping.
 void OperationFolder::notifyRemoval(Operation *op) {
@@ -221,9 +227,8 @@ bool OperationFolder::isFolderOwnedConstant(Operation *op) const {
 
 /// Tries to perform folding on the given `op`. If successful, populates
 /// `results` with the results of the folding.
-LogicalResult OperationFolder::tryToFold(
-    OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
-    function_ref<void(Operation *)> processGeneratedConstants) {
+LogicalResult OperationFolder::tryToFold(OpBuilder &builder, Operation *op,
+                                         SmallVectorImpl<Value> &results) {
   SmallVector<Attribute, 8> operandConstants;
 
   // If this is a commutative operation, move constants to be trailing operands.
@@ -252,16 +257,15 @@ LogicalResult OperationFolder::tryToFold(
   // fold.
   SmallVector<OpFoldResult, 8> foldResults;
   if (failed(op->fold(operandConstants, foldResults)) ||
-      failed(processFoldResults(builder, op, results, foldResults,
-                                processGeneratedConstants)))
+      failed(processFoldResults(builder, op, results, foldResults)))
     return success(updatedOpOperands);
   return success();
 }
 
-LogicalResult OperationFolder::processFoldResults(
-    OpBuilder &builder, Operation *op, SmallVectorImpl<Value> &results,
-    ArrayRef<OpFoldResult> foldResults,
-    function_ref<void(Operation *)> processGeneratedConstants) {
+LogicalResult
+OperationFolder::processFoldResults(OpBuilder &builder, Operation *op,
+                                    SmallVectorImpl<Value> &results,
+                                    ArrayRef<OpFoldResult> foldResults) {
   // Check to see if the operation was just updated in place.
   if (foldResults.empty())
     return success();
@@ -312,20 +316,13 @@ LogicalResult OperationFolder::processFoldResults(
     // If materialization fails, cleanup any operations generated for the
     // previous results and return failure.
     for (Operation &op : llvm::make_early_inc_range(
-             llvm::make_range(entry.begin(), builder.getInsertionPoint()))) {
-      notifyRemoval(&op);
-      op.erase();
-    }
+             llvm::make_range(entry.begin(), builder.getInsertionPoint())))
+      eraseOp(&op);
+
     results.clear();
     return failure();
   }
 
-  // Process any newly generated operations.
-  if (processGeneratedConstants) {
-    for (auto i = entry.begin(), e = builder.getInsertionPoint(); i != e; ++i)
-      processGeneratedConstants(&*i);
-  }
-
   return success();
 }
 
@@ -358,7 +355,7 @@ Operation *OperationFolder::tryGetOrCreateConstant(
   // If an existing operation in the new dialect already exists, delete the
   // materialized operation in favor of the existing one.
   if (auto *existingOp = uniquedConstants.lookup(newKey)) {
-    constOp->erase();
+    eraseOp(constOp);
     referencedDialects[existingOp].push_back(dialect);
     return constOp = existingOp;
   }

diff  --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 15495efca09af..a5977ccec1182 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -127,7 +127,8 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
 GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
     MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
     const GreedyRewriteConfig &config)
-    : PatternRewriter(ctx), folder(ctx), config(config), matcher(patterns) {
+    : PatternRewriter(ctx), folder(ctx, this), config(config),
+      matcher(patterns) {
   worklist.reserve(64);
 
   // Apply a simple cost model based solely on pattern benefit.
@@ -156,9 +157,6 @@ bool GreedyPatternRewriteDriver::processWorklist() {
   };
 #endif
 
-  // These are scratch vectors used in the folding loop below.
-  SmallVector<Value, 8> originalOperands;
-
   bool changed = false;
   int64_t numRewrites = 0;
   while (!worklist.empty() &&
@@ -197,34 +195,11 @@ bool GreedyPatternRewriteDriver::processWorklist() {
       continue;
     }
 
-    // Collects all the operands and result uses of the given `op` into work
-    // list. Also remove `op` and nested ops from worklist.
-    originalOperands.assign(op->operand_begin(), op->operand_end());
-    auto preReplaceAction = [&](Operation *op) {
-      // Add the operands to the worklist for visitation.
-      addOperandsToWorklist(originalOperands);
-
-      // Add all the users of the result to the worklist so we make sure
-      // to revisit them.
-      for (auto result : op->getResults())
-        for (auto *userOp : result.getUsers())
-          addToWorklist(userOp);
-
-      notifyOperationRemoved(op);
-    };
-
-    // Add the given operation to the worklist.
-    auto collectOps = [this](Operation *op) { addToWorklist(op); };
-
     // Try to fold this op.
-    bool inPlaceUpdate;
-    if ((succeeded(folder.tryToFold(op, collectOps, preReplaceAction,
-                                    &inPlaceUpdate)))) {
+    if (succeeded(folder.tryToFold(op))) {
       LLVM_DEBUG(logResultWithLine("success", "operation was folded"));
-
       changed = true;
-      if (!inPlaceUpdate)
-        continue;
+      continue;
     }
 
     // Try to match one of the patterns. The rewriter is automatically
@@ -465,7 +440,7 @@ LogicalResult RegionPatternRewriteDriver::simplify() && {
       // Add all nested operations to the worklist in preorder.
       region.walk<WalkOrder::PreOrder>([&](Operation *op) {
         if (!insertKnownConstant(op)) {
-          worklist.push_back(op);
+          addToWorklist(op);
           return WalkResult::advance();
         }
         return WalkResult::skip();

diff  --git a/mlir/test/lib/Dialect/Test/TestDialect.cpp b/mlir/test/lib/Dialect/Test/TestDialect.cpp
index dc5f629610d90..d27349710c451 100644
--- a/mlir/test/lib/Dialect/Test/TestDialect.cpp
+++ b/mlir/test/lib/Dialect/Test/TestDialect.cpp
@@ -1116,7 +1116,8 @@ LogicalResult TestOpWithVariadicResultsAndFolder::fold(
 }
 
 OpFoldResult TestOpInPlaceFold::fold(FoldAdaptor adaptor) {
-  if (adaptor.getOp()) {
+  if (adaptor.getOp() && !(*this)->hasAttr("attr")) {
+    // The folder adds "attr" if not present.
     (*this)->setAttr("attr", adaptor.getOp());
     return getResult();
   }

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index e9816ebcc13e8..b4447eb8cc142 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1280,7 +1280,7 @@ def TestOpInPlaceFoldAnchor : TEST_Op<"op_in_place_fold_anchor"> {
 }
 
 def TestOpInPlaceFold : TEST_Op<"op_in_place_fold"> {
-  let arguments = (ins I32:$op, I32Attr:$attr);
+  let arguments = (ins I32:$op, OptionalAttr<I32Attr>:$attr);
   let results = (outs I32);
   let hasFolder = 1;
 }

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 4bfbb3496ec3a..66c369d845bdf 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -93,8 +93,7 @@ struct FoldingPattern : public RewritePattern {
     // (unchanged) operation result.
     OperationFolder folder(op->getContext());
     Value result = folder.create<TestOpInPlaceFold>(
-        rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0),
-        rewriter.getI32IntegerAttr(0));
+        rewriter, op->getLoc(), rewriter.getIntegerType(32), op->getOperand(0));
     assert(result);
     rewriter.replaceOp(op, result);
     return success();

diff  --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp
index 1af923e014d52..9896cf372d269 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -13,8 +13,8 @@ using namespace mlir;
 
 namespace {
 /// Simple constant folding pass.
-struct TestConstantFold
-    : public PassWrapper<TestConstantFold, OperationPass<>> {
+struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
+                          public RewriterBase::Listener {
   MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestConstantFold)
 
   StringRef getArgument() const final { return "test-constant-fold"; }
@@ -26,17 +26,22 @@ struct TestConstantFold
 
   void foldOperation(Operation *op, OperationFolder &helper);
   void runOnOperation() override;
+
+  void notifyOperationInserted(Operation *op) override {
+    existingConstants.push_back(op);
+  }
+  void notifyOperationRemoved(Operation *op) override {
+    auto it = llvm::find(existingConstants, op);
+    if (it != existingConstants.end())
+      existingConstants.erase(it);
+  }
 };
 } // namespace
 
 void TestConstantFold::foldOperation(Operation *op, OperationFolder &helper) {
-  auto processGeneratedConstants = [this](Operation *op) {
-    existingConstants.push_back(op);
-  };
-
   // Attempt to fold the specified operation, including handling unused or
   // duplicated constants.
-  (void)helper.tryToFold(op, processGeneratedConstants);
+  (void)helper.tryToFold(op);
 }
 
 void TestConstantFold::runOnOperation() {
@@ -50,7 +55,7 @@ void TestConstantFold::runOnOperation() {
   // folding are at the beginning. This creates somewhat of a linear ordering to
   // the newly generated constants that matches the operation order and improves
   // the readability of test cases.
-  OperationFolder helper(&getContext());
+  OperationFolder helper(&getContext(), /*listener=*/this);
   for (Operation *op : llvm::reverse(ops))
     foldOperation(op, helper);
 


        


More information about the Mlir-commits mailing list