[llvm-branch-commits] [mlir] [mlir][Transforms] Add listener support to dialect conversion (PR #83425)

Matthias Springer via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 29 05:06:42 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83425

This commit adds listener support to the dialect conversion. Similarly to the greedy pattern rewrite driver, an optional listener can be specified in the configuration object.

Listeners are notified only if the dialect conversion succeeds. In case of a failure, where some IR changes are first performed and then rolled back, no notifications are sent.

Due to the fact that some kinds of rewrite are reflected in the IR immediately and some in a delayed fashion, there are certain limitations when attaching a listener; these are documented in `ConversionConfig`. To summarize, users are always notified about all rewrites that happened, but the notifications are sent all at once at the very end, and not interleaved with the actual IR changes.

This change is in preparation for improvements to `transform.apply_conversion_patterns`, which currently invalidates all handles. In the future, it can use a listener to update handles accordingly, similar to `transform.apply_patterns`.


>From 216e0b2d62e418cddcb6e4ecf6b07be361141131 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Thu, 29 Feb 2024 12:58:26 +0000
Subject: [PATCH] [mlir][Transforms] Add listener support to dialect conversion

---
 .../mlir/Transforms/DialectConversion.h       |  26 +++
 .../Transforms/Utils/DialectConversion.cpp    | 174 +++++++++++++-----
 mlir/test/Transforms/test-legalizer.mlir      |  71 ++++++-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  28 ++-
 4 files changed, 248 insertions(+), 51 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..98944b8a1ea648 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1091,6 +1091,32 @@ struct ConversionConfig {
   /// IR during an analysis conversion and only pre-existing operations are
   /// added to the set.
   DenseSet<Operation *> *legalizableOps = nullptr;
+
+  /// An optional listener that is notified about all IR modifications in case
+  /// dialect conversion succeeds. If the dialect conversion fails and no IR
+  /// modifications are visible (i.e., they were all rolled back), no
+  /// notifications are sent.
+  ///
+  /// Note: Notifications are sent in a delayed fashion, when the dialect
+  /// conversion is guaranteed to succeed. At that point, some IR modifications
+  /// may already have been materialized. Consequently, operations/blocks that
+  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+  /// guaranteed to be valid pointers and accessing op names is allowed. But
+  /// there are no guarantees about the state of ops/blocks at the time that a
+  /// callback is triggered.)
+  ///
+  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
+  /// created and inserted, and later moved to another block. (Moving ops also
+  /// triggers "notifyOperationInserted".)
+  ///
+  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+  ///
+  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+  /// immediately performed by the dialect conversion (and rolled back upon
+  /// failure).
+  RewriterBase::Listener *listener = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9f6468402686bd..0048347cae9314 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -203,14 +203,22 @@ class IRRewrite {
   /// Roll back the rewrite. Operations may be erased during rollback.
   virtual void rollback() = 0;
 
-  /// Commit the rewrite. Operations may be unlinked from their blocks during
-  /// the commit phase, but they must not be erased yet. This is because
-  /// internal dialect conversion state (such as `mapping`) may still be using
-  /// them. Operations must be erased during cleanup.
-  virtual void commit() {}
+  /// Commit the rewrite. At this point, it is certain that the dialect
+  /// conversion will succeed. All IR modifications, except for operation
+  /// erasure, must be performed through the given rewriter.
+  ///
+  /// Instead of erasing operations, they should merely be unlinked from their
+  /// blocks during the commit phase and finally be erased during the cleanup
+  /// phase. This is because internal dialect conversion state (such as
+  /// `mapping`) may still be using them.
+  ///
+  /// Any IR modification that was already performed before the commit phase
+  /// (e.g., insertion of an op) must be communicated to the listener that may
+  /// be attached to the given rewriter.
+  virtual void commit(RewriterBase &rewriter) {}
 
   /// Cleanup operations. Cleanup is called after commit.
-  virtual void cleanup() {}
+  virtual void cleanup(RewriterBase &rewriter) {}
 
   Kind getKind() const { return kind; }
 
@@ -220,12 +228,6 @@ class IRRewrite {
   IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
       : kind(kind), rewriterImpl(rewriterImpl) {}
 
-  /// Erase the given op (unless it was already erased).
-  void eraseOp(Operation *op);
-
-  /// Erase the given block (unless it was already erased).
-  void eraseBlock(Block *block);
-
   const ConversionConfig &getConfig() const;
 
   const Kind kind;
@@ -264,6 +266,12 @@ class CreateBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::CreateBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+  }
+
   void rollback() override {
     // Unlink all of the operations within this block, they will be deleted
     // separately.
@@ -309,10 +317,17 @@ class EraseBlockRewrite : public BlockRewrite {
     block = nullptr;
   }
 
-  void commit() override {
+  void commit(RewriterBase &rewriter) override {
     // Erase the block.
     assert(block && "expected block");
     assert(block->empty() && "expected empty block");
+
+    // Notify the listener that the block is about to be erased.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyBlockErased(block);
+
+    // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
     block = nullptr;
@@ -339,6 +354,16 @@ class MoveBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::MoveBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyBlockInserted(block, /*previous=*/region,
+                                    /*previousIt=*/{});
+    }
+  }
+
   void rollback() override {
     // Move the block back to its original position.
     Region::iterator before =
@@ -394,7 +419,7 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   LogicalResult
   materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -423,7 +448,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::ReplaceBlockArg;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -463,6 +488,17 @@ class MoveOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::MoveOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyOperationInserted(
+          op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
+                                                  /*insertPt=*/{}));
+    }
+  }
+
   void rollback() override {
     // Move the operation back to its original position.
     Block::iterator before =
@@ -506,7 +542,12 @@ class ModifyOperationRewrite : public OperationRewrite {
            "rewrite was neither committed nor rolled back");
   }
 
-  void commit() override {
+  void commit(RewriterBase &rewriter) override {
+    // Notify the listener that the operation was modified in-place.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyOperationModified(op);
+
     if (propertiesStorage) {
       OpaqueProperties propCopy(propertiesStorage);
       // Note: The operation may have been erased in the mean time, so
@@ -557,11 +598,11 @@ class ReplaceOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::ReplaceOperation;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   const TypeConverter *getConverter() const { return converter; }
 
@@ -586,6 +627,12 @@ class CreateOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::CreateOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyOperationInserted(op, /*previous=*/{});
+  }
+
   void rollback() override;
 };
 
@@ -623,7 +670,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
@@ -692,7 +739,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -853,6 +900,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     }
 
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
+
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
@@ -863,8 +911,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// This rewriter must be used for erasing ops/blocks.
-  SingleEraseRewriter eraseRewriter;
+  /// MLIR context.
+  MLIRContext *context;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -908,19 +956,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
-void IRRewrite::eraseOp(Operation *op) {
-  rewriterImpl.eraseRewriter.eraseOp(op);
-}
-
-void IRRewrite::eraseBlock(Block *block) {
-  rewriterImpl.eraseRewriter.eraseBlock(block);
-}
-
 const ConversionConfig &IRRewrite::getConfig() const {
   return rewriterImpl.config;
 }
 
-void BlockTypeConversionRewrite::commit() {
+void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
+  // Inform the listener about all IR modifications that have already taken
+  // place: References to the original block have been replaced with the new
+  // block.
+  if (auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
+          rewriter.getListener()))
+    for (Operation *op : block->getUsers())
+      listener->notifyOperationModified(op);
+
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
        llvm::zip_equal(origBlock->getArguments(), argInfo)) {
@@ -928,7 +976,7 @@ void BlockTypeConversionRewrite::commit() {
     if (!info) {
       if (Value newArg =
               rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        origArg.replaceAllUsesWith(newArg);
+        rewriter.replaceAllUsesWith(origArg, newArg);
       continue;
     }
 
@@ -938,8 +986,8 @@ void BlockTypeConversionRewrite::commit() {
 
     // If the argument is still used, replace it with the generated cast.
     if (!origArg.use_empty()) {
-      origArg.replaceAllUsesWith(
-          rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+                                               castValue, origArg.getType()));
     }
   }
 }
@@ -995,13 +1043,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
-void ReplaceBlockArgRewrite::commit() {
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
     return;
 
   if (isa<BlockArgument>(repl)) {
-    arg.replaceAllUsesWith(repl);
+    rewriter.replaceAllUsesWith(arg, repl);
     return;
   }
 
@@ -1010,7 +1058,7 @@ void ReplaceBlockArgRewrite::commit() {
   // replacement value.
   Operation *replOp = cast<OpResult>(repl).getOwner();
   Block *replBlock = replOp->getBlock();
-  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
@@ -1018,14 +1066,40 @@ void ReplaceBlockArgRewrite::commit() {
 
 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
 
-void ReplaceOperationRewrite::commit() {
-  for (OpResult result : op->getResults())
-    if (Value newValue =
-            rewriterImpl.mapping.lookupOrNull(result, result.getType()))
-      result.replaceAllUsesWith(newValue);
+void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
+  auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
+      rewriter.getListener());
+
+  // Compute replacement values.
+  SmallVector<Value> replacements =
+      llvm::map_to_vector(op->getResults(), [&](OpResult result) {
+        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+      });
+
+  // Notify the listener that the operation is about to be replaced.
+  if (listener)
+    listener->notifyOperationReplaced(op, replacements);
+
+  // Replace all uses with the new values.
+  for (auto [result, newValue] :
+       llvm::zip_equal(op->getResults(), replacements))
+    if (newValue)
+      rewriter.replaceAllUsesWith(result, newValue);
+
+  // The original op will be erased, so remove it from the set of unlegalized
+  // ops.
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
+
+  // Notify the listener that the operation (and its nested operations) was
+  // erased.
+  if (listener) {
+    op->walk<WalkOrder::PostOrder>(
+        [&](Operation *op) { listener->notifyOperationErased(op); });
+  }
+
   // Do not erase the operation yet. It may still be referenced in `mapping`.
+  // Just unlink it for now and erase it during cleanup.
   op->getBlock()->getOperations().remove(op);
 }
 
@@ -1034,7 +1108,9 @@ void ReplaceOperationRewrite::rollback() {
     rewriterImpl.mapping.erase(result);
 }
 
-void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void CreateOperationRewrite::rollback() {
   for (Region &region : op->getRegions()) {
@@ -1053,14 +1129,20 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
+  IRRewriter rewriter(context, config.listener);
   for (auto &rewrite : rewrites)
-    rewrite->commit();
+    rewrite->commit(rewriter);
+
+  // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
-    rewrite->cleanup();
+    rewrite->cleanup(eraseRewriter);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 62d776cd7573ee..8af8102adf9754 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,5 +1,10 @@
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
 
+//      CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
+
 // CHECK-LABEL: verifyDirectPattern
 func.func @verifyDirectPattern() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 {
   return %result : i32
 }
 
+// -----
+
+//      CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
+// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+
 // CHECK-LABEL: verifyLargerBenefit
 func.func @verifyLargerBenefit() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 {
   return %result : i32
 }
 
+// -----
+
+// CHECK: notifyOperationModified: func.func
+// Note: No block insertion because this function is external and no block
+// signature conversion is performed.
+
 // CHECK-LABEL: func private @remap_input_1_to_0()
 func.func private @remap_input_1_to_0(i16)
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
 func.func @remap_input_1_to_1(%arg0: i64) {
   // CHECK-NEXT: "test.valid"{{.*}} : (f64)
   "test.invalid"(%arg0) : (i64) -> ()
 }
 
-// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+// CHECK: func @remap_call_1_to_1(%arg0: f64)
 func.func @remap_call_1_to_1(%arg0: i64) {
   // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
   call @remap_input_1_to_1(%arg0) : (i64) -> ()
@@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) {
   return
 }
 
+// -----
+
+// Block signature conversion: new block is inserted.
+// CHECK:      notifyBlockInserted into func.func: was unlinked
+
+// Contents of the old block are moved to the new block.
+// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+
+// The new block arguments are used in "test.return".
+// CHECK-NEXT: notifyOperationModified: test.return
+
+// The old block is erased.
+// CHECK-NEXT: notifyBlockErased
+
+// The function op gets a new type attribute.
+// CHECK-NEXT: notifyOperationModified: func.func
+
+// "test.return" is replaced.
+// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.return
+// CHECK-NEXT: notifyOperationErased: test.return
+
 // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
 func.func @remap_input_1_to_N(%arg0: f32) -> f32 {
   // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
   "test.return"(%arg0) : (f32) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
 func.func @remap_input_1_to_N_remaining_use(%arg0: f32) {
   // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
@@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) {
   "test.return"(%arg0) : (i42) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_to_self
 func.func @remap_input_to_self(%arg0: index) {
   // CHECK-NOT: test.cast
@@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
  "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @no_remap_nested
 func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
@@ -82,6 +133,8 @@ func.func @no_remap_nested() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_moved_region_args
 func.func @remap_moved_region_args() {
   // CHECK-NEXT: return
@@ -96,6 +149,8 @@ func.func @remap_moved_region_args() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_cloned_region_args
 func.func @remap_cloned_region_args() {
   // CHECK-NEXT: return
@@ -122,6 +177,8 @@ func.func @remap_drop_region() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @dropped_input_in_use
 func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
@@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   "work"(%arg) : (i16) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @up_to_date_replacement
 func.func @up_to_date_replacement(%arg: i8) -> i8 {
   // CHECK-NEXT: return
@@ -139,6 +198,8 @@ func.func @up_to_date_replacement(%arg: i8) -> i8 {
   return %repl_2 : i8
 }
 
+// -----
+
 // CHECK-LABEL: func @remove_foldable_op
 // CHECK-SAME:                          (%[[ARG_0:[a-z0-9]*]]: i32)
 func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
@@ -150,6 +211,8 @@ func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
   return %0 : i32
 }
 
+// -----
+
 // CHECK-LABEL: @create_block
 func.func @create_block() {
   // Check that we created a block with arguments.
@@ -161,6 +224,12 @@ func.func @create_block() {
   return
 }
 
+// -----
+
+//      CHECK: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+
 // CHECK-LABEL: @bounded_recursion
 func.func @bounded_recursion() {
   // CHECK: test.recursive_rewrite 0
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index abc0e43c7b7f2d..2628a2784c18da 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -327,8 +327,12 @@ struct TestPatternDriver
 struct DumpNotifications : public RewriterBase::Listener {
   void notifyBlockInserted(Block *block, Region *previous,
                            Region::iterator previousIt) override {
-    llvm::outs() << "notifyBlockInserted into "
-                 << block->getParentOp()->getName() << ": ";
+    llvm::outs() << "notifyBlockInserted";
+    if (block->getParentOp()) {
+      llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
+    } else {
+      llvm::outs() << " into unknown op: ";
+    }
     if (previous == nullptr) {
       llvm::outs() << "was unlinked\n";
     } else {
@@ -341,7 +345,9 @@ struct DumpNotifications : public RewriterBase::Listener {
     if (!previous.isSet()) {
       llvm::outs() << ", was unlinked\n";
     } else {
-      if (previous.getPoint() == previous.getBlock()->end()) {
+      if (!previous.getPoint().getNodePtr()) {
+        llvm::outs() << ", was linked, exact position unknown\n";
+      } else if (previous.getPoint() == previous.getBlock()->end()) {
         llvm::outs() << ", was last in block\n";
       } else {
         llvm::outs() << ", previous = " << previous.getPoint()->getName()
@@ -349,9 +355,18 @@ struct DumpNotifications : public RewriterBase::Listener {
       }
     }
   }
+  void notifyBlockErased(Block *block) override {
+    llvm::outs() << "notifyBlockErased\n";
+  }
   void notifyOperationErased(Operation *op) override {
     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
   }
+  void notifyOperationModified(Operation *op) override {
+    llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
+  }
+  void notifyOperationReplaced(Operation *op, ValueRange values) override {
+    llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
+  }
 };
 
 struct TestStrictPatternDriver
@@ -1153,6 +1168,8 @@ struct TestLegalizePatternDriver
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
       ConversionConfig config;
+      DumpNotifications dumpNotifications;
+      config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns), config))) {
@@ -1171,8 +1188,11 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
+      ConversionConfig config;
+      DumpNotifications dumpNotifications;
+      config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
-                                     std::move(patterns)))) {
+                                     std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyFullConversion failed";
       }
       return;



More information about the llvm-branch-commits mailing list