[Mlir-commits] [mlir] [mlir][Transforms] Add listener support to dialect conversion (PR #83425)

Matthias Springer llvmlistbot at llvm.org
Thu Mar 7 17:13:53 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/83425

>From d91a0fb6cb8a8559588022b954994af99140833d Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Fri, 8 Mar 2024 01:09:56 +0000
Subject: [PATCH] [mlir][Transforms] Add listener support to dialect conversion

---
 .../mlir/Transforms/DialectConversion.h       |  33 +++
 .../Transforms/Utils/DialectConversion.cpp    | 225 ++++++++++++++----
 mlir/test/Transforms/test-legalizer.mlir      |  71 +++++-
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  28 ++-
 4 files changed, 302 insertions(+), 55 deletions(-)

diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 01fde101ef3cb6d..83198c9b0db5455 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1085,6 +1085,39 @@ struct ConversionConfig {
   /// IR during an analysis conversion and only pre-existing operations are
   /// added to the set.
   DenseSet<Operation *> *legalizableOps = nullptr;
+
+  /// An optional listener that is notified about all IR modifications in case
+  /// dialect conversion succeeds. If the dialect conversion fails and no IR
+  /// modifications are visible (i.e., they were all rolled back), no
+  /// notifications are sent.
+  ///
+  /// Note: Notifications are sent in a delayed fashion, when the dialect
+  /// conversion is guaranteed to succeed. At that point, some IR modifications
+  /// may already have been materialized. Consequently, operations/blocks that
+  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+  /// guaranteed to be valid pointers and accessing op names is allowed. But
+  /// there are no guarantees about the state of ops/blocks at the time that a
+  /// callback is triggered.)
+  ///
+  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
+  /// created and inserted, and later moved to another block. (Moving ops also
+  /// triggers "notifyOperationInserted".)
+  ///
+  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+  ///
+  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+  /// immediately performed by the dialect conversion (and rolled back upon
+  /// failure).
+  //
+  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
+  // callback, the previous region/block is provided to the callback, but not
+  // the iterator pointing to the exact location within the region/block. That
+  // is because these notifications are sent with a delay (after the IR has
+  // already been modified) and iterators into past IR state cannot be
+  // represented at the moment.
+  RewriterBase::Listener *listener = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8b2d71408a56516..c1a261eab8487d5 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -204,14 +204,22 @@ class IRRewrite {
   /// Roll back the rewrite. Operations may be erased during rollback.
   virtual void rollback() = 0;
 
-  /// Commit the rewrite. Operations/blocks may be unlinked during the commit
-  /// phase, but they must not be erased yet. This is because internal dialect
-  /// conversion state (such as `mapping`) may still be using them. Operations/
-  /// blocks must be erased during cleanup.
-  virtual void commit() {}
+  /// Commit the rewrite. At this point, it is certain that the dialect
+  /// conversion will succeed. All IR modifications, except for operation/block
+  /// erasure, must be performed through the given rewriter.
+  ///
+  /// Instead of erasing operations/blocks, they should merely be unlinked
+  /// during the commit phase and finally be erased during the cleanup phase.
+  /// This is because internal dialect conversion state (such as `mapping`)
+  /// may still be using them.
+  ///
+  /// Any IR modification that was already performed before the commit phase
+  /// (e.g., insertion of an op) must be communicated to the listener that may
+  /// be attached to the given rewriter.
+  virtual void commit(RewriterBase &rewriter) {}
 
   /// Cleanup operations/blocks. Cleanup is called after commit.
-  virtual void cleanup() {}
+  virtual void cleanup(RewriterBase &rewriter) {}
 
   Kind getKind() const { return kind; }
 
@@ -221,12 +229,6 @@ class IRRewrite {
   IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
       : kind(kind), rewriterImpl(rewriterImpl) {}
 
-  /// Erase the given op (unless it was already erased).
-  void eraseOp(Operation *op);
-
-  /// Erase the given block (unless it was already erased).
-  void eraseBlock(Block *block);
-
   const ConversionConfig &getConfig() const;
 
   const Kind kind;
@@ -265,6 +267,12 @@ class CreateBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::CreateBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+  }
+
   void rollback() override {
     // Unlink all of the operations within this block, they will be deleted
     // separately.
@@ -311,10 +319,19 @@ class EraseBlockRewrite : public BlockRewrite {
     block = nullptr;
   }
 
-  void cleanup() override {
+  void commit(RewriterBase &rewriter) override {
     // Erase the block.
     assert(block && "expected block");
     assert(block->empty() && "expected empty block");
+
+    // Notify the listener that the block is about to be erased.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyBlockErased(block);
+  }
+
+  void cleanup(RewriterBase &rewriter) override {
+    // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
     block = nullptr;
@@ -341,6 +358,13 @@ class InlineBlockRewrite : public BlockRewrite {
         firstInlinedInst(sourceBlock->empty() ? nullptr
                                               : &sourceBlock->front()),
         lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+    // If a listener is attached to the dialect conversion, ops must be moved
+    // one-by-one. When they are moved in bulk, notifications cannot be sent
+    // because the ops that used to be in the source block at the time of the
+    // inlining (before the "commit" phase) are unknown at the time when
+    // notifications are sent (which is during the "commit" phase).
+    assert(!getConfig().listener &&
+           "InlineBlockRewrite not supported if listener is attached");
   }
 
   static bool classof(const IRRewrite *rewrite) {
@@ -382,6 +406,16 @@ class MoveBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::MoveBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyBlockInserted(block, /*previous=*/region,
+                                    /*previousIt=*/{});
+    }
+  }
+
   void rollback() override {
     // Move the block back to its original position.
     Region::iterator before =
@@ -437,7 +471,7 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   LogicalResult
   materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -466,7 +500,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::ReplaceBlockArg;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -506,6 +540,17 @@ class MoveOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::MoveOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyOperationInserted(
+          op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
+                                                  /*insertPt=*/{}));
+    }
+  }
+
   void rollback() override {
     // Move the operation back to its original position.
     Block::iterator before =
@@ -549,7 +594,12 @@ class ModifyOperationRewrite : public OperationRewrite {
            "rewrite was neither committed nor rolled back");
   }
 
-  void commit() override {
+  void commit(RewriterBase &rewriter) override {
+    // Notify the listener that the operation was modified in-place.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyOperationModified(op);
+
     if (propertiesStorage) {
       OpaqueProperties propCopy(propertiesStorage);
       // Note: The operation may have been erased in the mean time, so
@@ -600,11 +650,11 @@ class ReplaceOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::ReplaceOperation;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   const TypeConverter *getConverter() const { return converter; }
 
@@ -629,6 +679,12 @@ class CreateOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::CreateOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyOperationInserted(op, /*previous=*/{});
+  }
+
   void rollback() override;
 };
 
@@ -666,7 +722,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
@@ -735,7 +791,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -900,6 +956,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     }
 
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
+
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
@@ -910,8 +967,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// This rewriter must be used for erasing ops/blocks.
-  SingleEraseRewriter eraseRewriter;
+  /// MLIR context.
+  MLIRContext *context;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -955,19 +1012,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
-void IRRewrite::eraseOp(Operation *op) {
-  rewriterImpl.eraseRewriter.eraseOp(op);
-}
-
-void IRRewrite::eraseBlock(Block *block) {
-  rewriterImpl.eraseRewriter.eraseBlock(block);
-}
-
 const ConversionConfig &IRRewrite::getConfig() const {
   return rewriterImpl.config;
 }
 
-void BlockTypeConversionRewrite::commit() {
+void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
+  // Inform the listener about all IR modifications that have already taken
+  // place: References to the original block have been replaced with the new
+  // block.
+  if (auto *listener =
+          dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+    for (Operation *op : block->getUsers())
+      listener->notifyOperationModified(op);
+
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
        llvm::zip_equal(origBlock->getArguments(), argInfo)) {
@@ -975,7 +1032,7 @@ void BlockTypeConversionRewrite::commit() {
     if (!info) {
       if (Value newArg =
               rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        origArg.replaceAllUsesWith(newArg);
+        rewriter.replaceAllUsesWith(origArg, newArg);
       continue;
     }
 
@@ -985,8 +1042,8 @@ void BlockTypeConversionRewrite::commit() {
 
     // If the argument is still used, replace it with the generated cast.
     if (!origArg.use_empty()) {
-      origArg.replaceAllUsesWith(
-          rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+                                               castValue, origArg.getType()));
     }
   }
 }
@@ -1042,13 +1099,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
-void ReplaceBlockArgRewrite::commit() {
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
     return;
 
   if (isa<BlockArgument>(repl)) {
-    arg.replaceAllUsesWith(repl);
+    rewriter.replaceAllUsesWith(arg, repl);
     return;
   }
 
@@ -1057,7 +1114,7 @@ void ReplaceBlockArgRewrite::commit() {
   // replacement value.
   Operation *replOp = cast<OpResult>(repl).getOwner();
   Block *replBlock = replOp->getBlock();
-  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
@@ -1065,14 +1122,40 @@ void ReplaceBlockArgRewrite::commit() {
 
 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
 
-void ReplaceOperationRewrite::commit() {
-  for (OpResult result : op->getResults())
-    if (Value newValue =
-            rewriterImpl.mapping.lookupOrNull(result, result.getType()))
-      result.replaceAllUsesWith(newValue);
+void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
+  auto *listener =
+      dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener());
+
+  // Compute replacement values.
+  SmallVector<Value> replacements =
+      llvm::map_to_vector(op->getResults(), [&](OpResult result) {
+        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+      });
+
+  // Notify the listener that the operation is about to be replaced.
+  if (listener)
+    listener->notifyOperationReplaced(op, replacements);
+
+  // Replace all uses with the new values.
+  for (auto [result, newValue] :
+       llvm::zip_equal(op->getResults(), replacements))
+    if (newValue)
+      rewriter.replaceAllUsesWith(result, newValue);
+
+  // The original op will be erased, so remove it from the set of unlegalized
+  // ops.
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
+
+  // Notify the listener that the operation (and its nested operations) was
+  // erased.
+  if (listener) {
+    op->walk<WalkOrder::PostOrder>(
+        [&](Operation *nested) { listener->notifyOperationErased(nested); });
+  }
+
   // Do not erase the operation yet. It may still be referenced in `mapping`.
+  // Just unlink it for now and erase it during cleanup.
   op->getBlock()->getOperations().remove(op);
 }
 
@@ -1081,7 +1164,9 @@ void ReplaceOperationRewrite::rollback() {
     rewriterImpl.mapping.erase(result);
 }
 
-void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void CreateOperationRewrite::rollback() {
   for (Region &region : op->getRegions()) {
@@ -1100,14 +1185,20 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
+  IRRewriter rewriter(context, config.listener);
   for (auto &rewrite : rewrites)
-    rewrite->commit();
+    rewrite->commit(rewriter);
+
+  // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
-    rewrite->cleanup();
+    rewrite->cleanup(eraseRewriter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1306,8 +1397,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   Block *newBlock =
       rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
                            convertedTypes, newLocs);
-  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
-  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+
+  // If a listener is attached to the dialect conversion, ops cannot be moved
+  // to the destination block in bulk ("fast path"). This is because at the time
+  // the notifications are sent, it is unknown which ops were moved. Instead,
+  // ops should be moved one-by-one ("slow path"), so that a separate
+  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+  // a bit more efficient, so we try to do that when possible.
+  bool fastPath = !config.listener;
+  if (fastPath) {
+    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+    newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+  } else {
+    while (!block->empty())
+      rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
+  }
 
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
@@ -1645,10 +1749,31 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
          "expected 'source' to have no predecessors");
 #endif // NDEBUG
 
-  impl->notifyBlockBeingInlined(dest, source, before);
+  // If a listener is attached to the dialect conversion, ops cannot be moved
+  // to the destination block in bulk ("fast path"). This is because at the time
+  // the notifications are sent, it is unknown which ops were moved. Instead,
+  // ops should be moved one-by-one ("slow path"), so that a separate
+  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+  // a bit more efficient, so we try to do that when possible.
+  bool fastPath = !impl->config.listener;
+
+  if (fastPath)
+    impl->notifyBlockBeingInlined(dest, source, before);
+
+  // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
-  dest->getOperations().splice(before, source->getOperations());
+
+  if (fastPath) {
+    // Move all ops at once.
+    dest->getOperations().splice(before, source->getOperations());
+  } else {
+    // Move op by op.
+    while (!source->empty())
+      moveOpBefore(&source->front(), dest, before);
+  }
+
+  // Erase the source block.
   eraseBlock(source);
 }
 
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index ccdc9fe78ea0d37..d552f0346644b3f 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,5 +1,10 @@
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
 
+//      CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
+
 // CHECK-LABEL: verifyDirectPattern
 func.func @verifyDirectPattern() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 {
   return %result : i32
 }
 
+// -----
+
+//      CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
+// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+
 // CHECK-LABEL: verifyLargerBenefit
 func.func @verifyLargerBenefit() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 {
   return %result : i32
 }
 
+// -----
+
+// CHECK: notifyOperationModified: func.func
+// Note: No block insertion because this function is external and no block
+// signature conversion is performed.
+
 // CHECK-LABEL: func private @remap_input_1_to_0()
 func.func private @remap_input_1_to_0(i16)
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
 func.func @remap_input_1_to_1(%arg0: i64) {
   // CHECK-NEXT: "test.valid"{{.*}} : (f64)
   "test.invalid"(%arg0) : (i64) -> ()
 }
 
-// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+// CHECK: func @remap_call_1_to_1(%arg0: f64)
 func.func @remap_call_1_to_1(%arg0: i64) {
   // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
   call @remap_input_1_to_1(%arg0) : (i64) -> ()
@@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) {
   return
 }
 
+// -----
+
+// Block signature conversion: new block is inserted.
+// CHECK:      notifyBlockInserted into func.func: was unlinked
+
+// Contents of the old block are moved to the new block.
+// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+
+// The new block arguments are used in "test.return".
+// CHECK-NEXT: notifyOperationModified: test.return
+
+// The old block is erased.
+// CHECK-NEXT: notifyBlockErased
+
+// The function op gets a new type attribute.
+// CHECK-NEXT: notifyOperationModified: func.func
+
+// "test.return" is replaced.
+// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.return
+// CHECK-NEXT: notifyOperationErased: test.return
+
 // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
 func.func @remap_input_1_to_N(%arg0: f32) -> f32 {
   // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
   "test.return"(%arg0) : (f32) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
 func.func @remap_input_1_to_N_remaining_use(%arg0: f32) {
   // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
@@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) {
   "test.return"(%arg0) : (i42) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_to_self
 func.func @remap_input_to_self(%arg0: index) {
   // CHECK-NOT: test.cast
@@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
  "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @no_remap_nested
 func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
@@ -82,6 +133,8 @@ func.func @no_remap_nested() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_moved_region_args
 func.func @remap_moved_region_args() {
   // CHECK-NEXT: return
@@ -96,6 +149,8 @@ func.func @remap_moved_region_args() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_cloned_region_args
 func.func @remap_cloned_region_args() {
   // CHECK-NEXT: return
@@ -122,6 +177,8 @@ func.func @remap_drop_region() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @dropped_input_in_use
 func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
@@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   "work"(%arg) : (i16) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @up_to_date_replacement
 func.func @up_to_date_replacement(%arg: i8) -> i8 {
   // CHECK-NEXT: return
@@ -139,6 +198,8 @@ func.func @up_to_date_replacement(%arg: i8) -> i8 {
   return %repl_2 : i8
 }
 
+// -----
+
 // CHECK-LABEL: func @remove_foldable_op
 // CHECK-SAME:                          (%[[ARG_0:[a-z0-9]*]]: i32)
 func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
@@ -150,6 +211,8 @@ func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
   return %0 : i32
 }
 
+// -----
+
 // CHECK-LABEL: @create_block
 func.func @create_block() {
   // Check that we created a block with arguments.
@@ -161,6 +224,12 @@ func.func @create_block() {
   return
 }
 
+// -----
+
+//      CHECK: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+
 // CHECK-LABEL: @bounded_recursion
 func.func @bounded_recursion() {
   // CHECK: test.recursive_rewrite 0
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 27eae2ffd694b5a..2da184bc3d85baa 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -327,8 +327,12 @@ struct TestPatternDriver
 struct DumpNotifications : public RewriterBase::Listener {
   void notifyBlockInserted(Block *block, Region *previous,
                            Region::iterator previousIt) override {
-    llvm::outs() << "notifyBlockInserted into "
-                 << block->getParentOp()->getName() << ": ";
+    llvm::outs() << "notifyBlockInserted";
+    if (block->getParentOp()) {
+      llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
+    } else {
+      llvm::outs() << " into unknown op: ";
+    }
     if (previous == nullptr) {
       llvm::outs() << "was unlinked\n";
     } else {
@@ -341,7 +345,9 @@ struct DumpNotifications : public RewriterBase::Listener {
     if (!previous.isSet()) {
       llvm::outs() << ", was unlinked\n";
     } else {
-      if (previous.getPoint() == previous.getBlock()->end()) {
+      if (!previous.getPoint().getNodePtr()) {
+        llvm::outs() << ", was linked, exact position unknown\n";
+      } else if (previous.getPoint() == previous.getBlock()->end()) {
         llvm::outs() << ", was last in block\n";
       } else {
         llvm::outs() << ", previous = " << previous.getPoint()->getName()
@@ -349,9 +355,18 @@ struct DumpNotifications : public RewriterBase::Listener {
       }
     }
   }
+  void notifyBlockErased(Block *block) override {
+    llvm::outs() << "notifyBlockErased\n";
+  }
   void notifyOperationErased(Operation *op) override {
     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
   }
+  void notifyOperationModified(Operation *op) override {
+    llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
+  }
+  void notifyOperationReplaced(Operation *op, ValueRange values) override {
+    llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
+  }
 };
 
 struct TestStrictPatternDriver
@@ -1153,6 +1168,8 @@ struct TestLegalizePatternDriver
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
       ConversionConfig config;
+      DumpNotifications dumpNotifications;
+      config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns), config))) {
@@ -1171,8 +1188,11 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
+      ConversionConfig config;
+      DumpNotifications dumpNotifications;
+      config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
-                                     std::move(patterns)))) {
+                                     std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyFullConversion failed";
       }
       return;



More information about the Mlir-commits mailing list