[Mlir-commits] [mlir] 60a20bd - [mlir][Transforms] Add listener support to dialect conversion (#83425)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Thu Mar 7 17:34:49 PST 2024


Author: Matthias Springer
Date: 2024-03-08T10:34:45+09:00
New Revision: 60a20bd6973c8fc7aa9a19465ed042604e07fb17

URL: https://github.com/llvm/llvm-project/commit/60a20bd6973c8fc7aa9a19465ed042604e07fb17
DIFF: https://github.com/llvm/llvm-project/commit/60a20bd6973c8fc7aa9a19465ed042604e07fb17.diff

LOG: [mlir][Transforms] Add listener support to dialect conversion (#83425)

This commit adds listener support to the dialect conversion. Similarly
to the greedy pattern rewrite driver, an optional listener can be
specified in the configuration object.

Listeners are notified only if the dialect conversion succeeds. In case
of a failure, where some IR changes are first performed and then rolled
back, no notifications are sent.

Because some kinds of rewrites are reflected in the IR immediately and
others only in a delayed fashion, there are certain limitations when
attaching a listener; these are documented in `ConversionConfig`.
To summarize, users are always notified about all rewrites that
happened, but the notifications are sent all at once at the very end,
and not interleaved with the actual IR changes.

This change is in preparation for improvements to
`transform.apply_conversion_patterns`, which currently invalidates all
handles. In the future, it can use a listener to update handles
accordingly, similar to `transform.apply_patterns`.

Added: 
    

Modified: 
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/Transforms/Utils/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 01fde101ef3cb6..83198c9b0db545 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1085,6 +1085,39 @@ struct ConversionConfig {
   /// IR during an analysis conversion and only pre-existing operations are
   /// added to the set.
   DenseSet<Operation *> *legalizableOps = nullptr;
+
+  /// An optional listener that is notified about all IR modifications in case
+  /// dialect conversion succeeds. If the dialect conversion fails and no IR
+  /// modifications are visible (i.e., they were all rolled back), no
+  /// notifications are sent.
+  ///
+  /// Note: Notifications are sent in a delayed fashion, when the dialect
+  /// conversion is guaranteed to succeed. At that point, some IR modifications
+  /// may already have been materialized. Consequently, operations/blocks that
+  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+  /// guaranteed to be valid pointers and accessing op names is allowed. But
+  /// there are no guarantees about the state of ops/blocks at the time that a
+  /// callback is triggered.)
+  ///
+  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
+  /// created and inserted, and later moved to another block. (Moving ops also
+  /// triggers "notifyOperationInserted".)
+  ///
+  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+  ///
+  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+  /// immediately performed by the dialect conversion (and rolled back upon
+  /// failure).
+  //
+  // Note: When receiving a "notifyBlockInserted"/"notifyOperationInserted"
+  // callback, the previous region/block is provided to the callback, but not
+  // the iterator pointing to the exact location within the region/block. That
+  // is because these notifications are sent with a delay (after the IR has
+  // already been modified) and iterators into past IR state cannot be
+  // represented at the moment.
+  RewriterBase::Listener *listener = nullptr;
 };
 
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 8b2d71408a5651..c1a261eab8487d 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -204,14 +204,22 @@ class IRRewrite {
   /// Roll back the rewrite. Operations may be erased during rollback.
   virtual void rollback() = 0;
 
-  /// Commit the rewrite. Operations/blocks may be unlinked during the commit
-  /// phase, but they must not be erased yet. This is because internal dialect
-  /// conversion state (such as `mapping`) may still be using them. Operations/
-  /// blocks must be erased during cleanup.
-  virtual void commit() {}
+  /// Commit the rewrite. At this point, it is certain that the dialect
+  /// conversion will succeed. All IR modifications, except for operation/block
+  /// erasure, must be performed through the given rewriter.
+  ///
+  /// Instead of erasing operations/blocks, they should merely be unlinked
+  /// during the commit phase and finally be erased during the cleanup
+  /// phase. This is because internal dialect conversion state (such as
+  /// `mapping`) may still be using them.
+  ///
+  /// Any IR modification that was already performed before the commit phase
+  /// (e.g., insertion of an op) must be communicated to the listener that may
+  /// be attached to the given rewriter.
+  virtual void commit(RewriterBase &rewriter) {}
 
   /// Cleanup operations/blocks. Cleanup is called after commit.
-  virtual void cleanup() {}
+  virtual void cleanup(RewriterBase &rewriter) {}
 
   Kind getKind() const { return kind; }
 
@@ -221,12 +229,6 @@ class IRRewrite {
   IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
       : kind(kind), rewriterImpl(rewriterImpl) {}
 
-  /// Erase the given op (unless it was already erased).
-  void eraseOp(Operation *op);
-
-  /// Erase the given block (unless it was already erased).
-  void eraseBlock(Block *block);
-
   const ConversionConfig &getConfig() const;
 
   const Kind kind;
@@ -265,6 +267,12 @@ class CreateBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::CreateBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+  }
+
   void rollback() override {
     // Unlink all of the operations within this block, they will be deleted
     // separately.
@@ -311,10 +319,19 @@ class EraseBlockRewrite : public BlockRewrite {
     block = nullptr;
   }
 
-  void cleanup() override {
+  void commit(RewriterBase &rewriter) override {
     // Erase the block.
     assert(block && "expected block");
     assert(block->empty() && "expected empty block");
+
+    // Notify the listener that the block is about to be erased.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyBlockErased(block);
+  }
+
+  void cleanup(RewriterBase &rewriter) override {
+    // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
     block = nullptr;
@@ -341,6 +358,13 @@ class InlineBlockRewrite : public BlockRewrite {
         firstInlinedInst(sourceBlock->empty() ? nullptr
                                               : &sourceBlock->front()),
         lastInlinedInst(sourceBlock->empty() ? nullptr : &sourceBlock->back()) {
+    // If a listener is attached to the dialect conversion, ops must be moved
+    // one-by-one. When they are moved in bulk, notifications cannot be sent
+    // because the ops that used to be in the source block at the time of the
+    // inlining (before the "commit" phase) are unknown at the time when
+    // notifications are sent (which is during the "commit" phase).
+    assert(!getConfig().listener &&
+           "InlineBlockRewrite not supported if listener is attached");
   }
 
   static bool classof(const IRRewrite *rewrite) {
@@ -382,6 +406,16 @@ class MoveBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::MoveBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyBlockInserted(block, /*previous=*/region,
+                                    /*previousIt=*/{});
+    }
+  }
+
   void rollback() override {
     // Move the block back to its original position.
     Region::iterator before =
@@ -437,7 +471,7 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   LogicalResult
   materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -466,7 +500,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::ReplaceBlockArg;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -506,6 +540,17 @@ class MoveOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::MoveOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyOperationInserted(
+          op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
+                                                  /*insertPt=*/{}));
+    }
+  }
+
   void rollback() override {
     // Move the operation back to its original position.
     Block::iterator before =
@@ -549,7 +594,12 @@ class ModifyOperationRewrite : public OperationRewrite {
            "rewrite was neither committed nor rolled back");
   }
 
-  void commit() override {
+  void commit(RewriterBase &rewriter) override {
+    // Notify the listener that the operation was modified in-place.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyOperationModified(op);
+
     if (propertiesStorage) {
       OpaqueProperties propCopy(propertiesStorage);
       // Note: The operation may have been erased in the mean time, so
@@ -600,11 +650,11 @@ class ReplaceOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::ReplaceOperation;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   const TypeConverter *getConverter() const { return converter; }
 
@@ -629,6 +679,12 @@ class CreateOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::CreateOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyOperationInserted(op, /*previous=*/{});
+  }
+
   void rollback() override;
 };
 
@@ -666,7 +722,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
@@ -735,7 +791,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -900,6 +956,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     }
 
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
+
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
@@ -910,8 +967,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// This rewriter must be used for erasing ops/blocks.
-  SingleEraseRewriter eraseRewriter;
+  /// MLIR context.
+  MLIRContext *context;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -955,19 +1012,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
-void IRRewrite::eraseOp(Operation *op) {
-  rewriterImpl.eraseRewriter.eraseOp(op);
-}
-
-void IRRewrite::eraseBlock(Block *block) {
-  rewriterImpl.eraseRewriter.eraseBlock(block);
-}
-
 const ConversionConfig &IRRewrite::getConfig() const {
   return rewriterImpl.config;
 }
 
-void BlockTypeConversionRewrite::commit() {
+void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
+  // Inform the listener about all IR modifications that have already taken
+  // place: References to the original block have been replaced with the new
+  // block.
+  if (auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
+          rewriter.getListener()))
+    for (Operation *op : block->getUsers())
+      listener->notifyOperationModified(op);
+
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
        llvm::zip_equal(origBlock->getArguments(), argInfo)) {
@@ -975,7 +1032,7 @@ void BlockTypeConversionRewrite::commit() {
     if (!info) {
       if (Value newArg =
               rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        origArg.replaceAllUsesWith(newArg);
+        rewriter.replaceAllUsesWith(origArg, newArg);
       continue;
     }
 
@@ -985,8 +1042,8 @@ void BlockTypeConversionRewrite::commit() {
 
     // If the argument is still used, replace it with the generated cast.
     if (!origArg.use_empty()) {
-      origArg.replaceAllUsesWith(
-          rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+                                               castValue, origArg.getType()));
     }
   }
 }
@@ -1042,13 +1099,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
-void ReplaceBlockArgRewrite::commit() {
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
     return;
 
   if (isa<BlockArgument>(repl)) {
-    arg.replaceAllUsesWith(repl);
+    rewriter.replaceAllUsesWith(arg, repl);
     return;
   }
 
@@ -1057,7 +1114,7 @@ void ReplaceBlockArgRewrite::commit() {
   // replacement value.
   Operation *replOp = cast<OpResult>(repl).getOwner();
   Block *replBlock = replOp->getBlock();
-  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
@@ -1065,14 +1122,40 @@ void ReplaceBlockArgRewrite::commit() {
 
 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
 
-void ReplaceOperationRewrite::commit() {
-  for (OpResult result : op->getResults())
-    if (Value newValue =
-            rewriterImpl.mapping.lookupOrNull(result, result.getType()))
-      result.replaceAllUsesWith(newValue);
+void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
+  auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
+      rewriter.getListener());
+
+  // Compute replacement values.
+  SmallVector<Value> replacements =
+      llvm::map_to_vector(op->getResults(), [&](OpResult result) {
+        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+      });
+
+  // Notify the listener that the operation is about to be replaced.
+  if (listener)
+    listener->notifyOperationReplaced(op, replacements);
+
+  // Replace all uses with the new values.
+  for (auto [result, newValue] :
+       llvm::zip_equal(op->getResults(), replacements))
+    if (newValue)
+      rewriter.replaceAllUsesWith(result, newValue);
+
+  // The original op will be erased, so remove it from the set of unlegalized
+  // ops.
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
+
+  // Notify the listener that the operation (and its nested operations) was
+  // erased.
+  if (listener) {
+    op->walk<WalkOrder::PostOrder>(
+        [&](Operation *op) { listener->notifyOperationErased(op); });
+  }
+
   // Do not erase the operation yet. It may still be referenced in `mapping`.
+  // Just unlink it for now and erase it during cleanup.
   op->getBlock()->getOperations().remove(op);
 }
 
@@ -1081,7 +1164,9 @@ void ReplaceOperationRewrite::rollback() {
     rewriterImpl.mapping.erase(result);
 }
 
-void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void CreateOperationRewrite::rollback() {
   for (Region &region : op->getRegions()) {
@@ -1100,14 +1185,20 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
+  IRRewriter rewriter(context, config.listener);
   for (auto &rewrite : rewrites)
-    rewrite->commit();
+    rewrite->commit(rewriter);
+
+  // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
-    rewrite->cleanup();
+    rewrite->cleanup(eraseRewriter);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1306,8 +1397,21 @@ Block *ConversionPatternRewriterImpl::applySignatureConversion(
   Block *newBlock =
       rewriter.createBlock(block->getParent(), std::next(block->getIterator()),
                            convertedTypes, newLocs);
-  appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
-  newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+
+  // If a listener is attached to the dialect conversion, ops cannot be moved
+  // to the destination block in bulk ("fast path"). This is because at the time
+  // the notifications are sent, it is unknown which ops were moved. Instead,
+  // ops should be moved one-by-one ("slow path"), so that a separate
+  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+  // a bit more efficient, so we try to do that when possible.
+  bool fastPath = !config.listener;
+  if (fastPath) {
+    appendRewrite<InlineBlockRewrite>(newBlock, block, newBlock->end());
+    newBlock->getOperations().splice(newBlock->end(), block->getOperations());
+  } else {
+    while (!block->empty())
+      rewriter.moveOpBefore(&block->front(), newBlock, newBlock->end());
+  }
 
   // Replace all uses of the old block with the new block.
   block->replaceAllUsesWith(newBlock);
@@ -1645,10 +1749,31 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
          "expected 'source' to have no predecessors");
 #endif // NDEBUG
 
-  impl->notifyBlockBeingInlined(dest, source, before);
+  // If a listener is attached to the dialect conversion, ops cannot be moved
+  // to the destination block in bulk ("fast path"). This is because at the time
+  // the notifications are sent, it is unknown which ops were moved. Instead,
+  // ops should be moved one-by-one ("slow path"), so that a separate
+  // `MoveOperationRewrite` is enqueued for each moved op. Moving ops in bulk is
+  // a bit more efficient, so we try to do that when possible.
+  bool fastPath = !impl->config.listener;
+
+  if (fastPath)
+    impl->notifyBlockBeingInlined(dest, source, before);
+
+  // Replace all uses of block arguments.
   for (auto it : llvm::zip(source->getArguments(), argValues))
     replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
-  dest->getOperations().splice(before, source->getOperations());
+
+  if (fastPath) {
+    // Move all ops at once.
+    dest->getOperations().splice(before, source->getOperations());
+  } else {
+    // Move op by op.
+    while (!source->empty())
+      moveOpBefore(&source->front(), dest, before);
+  }
+
+  // Erase the source block.
   eraseBlock(source);
 }
 

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index ccdc9fe78ea0d3..d552f0346644b3 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,5 +1,10 @@
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
 
+//      CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
+
 // CHECK-LABEL: verifyDirectPattern
 func.func @verifyDirectPattern() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 {
   return %result : i32
 }
 
+// -----
+
+//      CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
+// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+
 // CHECK-LABEL: verifyLargerBenefit
 func.func @verifyLargerBenefit() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 {
   return %result : i32
 }
 
+// -----
+
+// CHECK: notifyOperationModified: func.func
+// Note: No block insertion because this function is external and no block
+// signature conversion is performed.
+
 // CHECK-LABEL: func private @remap_input_1_to_0()
 func.func private @remap_input_1_to_0(i16)
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
 func.func @remap_input_1_to_1(%arg0: i64) {
   // CHECK-NEXT: "test.valid"{{.*}} : (f64)
   "test.invalid"(%arg0) : (i64) -> ()
 }
 
-// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+// CHECK: func @remap_call_1_to_1(%arg0: f64)
 func.func @remap_call_1_to_1(%arg0: i64) {
   // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
   call @remap_input_1_to_1(%arg0) : (i64) -> ()
@@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) {
   return
 }
 
+// -----
+
+// Block signature conversion: new block is inserted.
+// CHECK:      notifyBlockInserted into func.func: was unlinked
+
+// Contents of the old block are moved to the new block.
+// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+
+// The new block arguments are used in "test.return".
+// CHECK-NEXT: notifyOperationModified: test.return
+
+// The old block is erased.
+// CHECK-NEXT: notifyBlockErased
+
+// The function op gets a new type attribute.
+// CHECK-NEXT: notifyOperationModified: func.func
+
+// "test.return" is replaced.
+// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.return
+// CHECK-NEXT: notifyOperationErased: test.return
+
 // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
 func.func @remap_input_1_to_N(%arg0: f32) -> f32 {
   // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
   "test.return"(%arg0) : (f32) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
 func.func @remap_input_1_to_N_remaining_use(%arg0: f32) {
   // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
@@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) {
   "test.return"(%arg0) : (i42) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_to_self
 func.func @remap_input_to_self(%arg0: index) {
   // CHECK-NOT: test.cast
@@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
  "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @no_remap_nested
 func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
@@ -82,6 +133,8 @@ func.func @no_remap_nested() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_moved_region_args
 func.func @remap_moved_region_args() {
   // CHECK-NEXT: return
@@ -96,6 +149,8 @@ func.func @remap_moved_region_args() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_cloned_region_args
 func.func @remap_cloned_region_args() {
   // CHECK-NEXT: return
@@ -122,6 +177,8 @@ func.func @remap_drop_region() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @dropped_input_in_use
 func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
@@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   "work"(%arg) : (i16) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @up_to_date_replacement
 func.func @up_to_date_replacement(%arg: i8) -> i8 {
   // CHECK-NEXT: return
@@ -139,6 +198,8 @@ func.func @up_to_date_replacement(%arg: i8) -> i8 {
   return %repl_2 : i8
 }
 
+// -----
+
 // CHECK-LABEL: func @remove_foldable_op
 // CHECK-SAME:                          (%[[ARG_0:[a-z0-9]*]]: i32)
 func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
@@ -150,6 +211,8 @@ func.func @remove_foldable_op(%arg0 : i32) -> (i32) {
   return %0 : i32
 }
 
+// -----
+
 // CHECK-LABEL: @create_block
 func.func @create_block() {
   // Check that we created a block with arguments.
@@ -161,6 +224,12 @@ func.func @create_block() {
   return
 }
 
+// -----
+
+//      CHECK: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+// CHECK-NEXT: notifyOperationModified: test.recursive_rewrite
+
 // CHECK-LABEL: @bounded_recursion
 func.func @bounded_recursion() {
   // CHECK: test.recursive_rewrite 0

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 27eae2ffd694b5..2da184bc3d85ba 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -327,8 +327,12 @@ struct TestPatternDriver
 struct DumpNotifications : public RewriterBase::Listener {
   void notifyBlockInserted(Block *block, Region *previous,
                            Region::iterator previousIt) override {
-    llvm::outs() << "notifyBlockInserted into "
-                 << block->getParentOp()->getName() << ": ";
+    llvm::outs() << "notifyBlockInserted";
+    if (block->getParentOp()) {
+      llvm::outs() << " into " << block->getParentOp()->getName() << ": ";
+    } else {
+      llvm::outs() << " into unknown op: ";
+    }
     if (previous == nullptr) {
       llvm::outs() << "was unlinked\n";
     } else {
@@ -341,7 +345,9 @@ struct DumpNotifications : public RewriterBase::Listener {
     if (!previous.isSet()) {
       llvm::outs() << ", was unlinked\n";
     } else {
-      if (previous.getPoint() == previous.getBlock()->end()) {
+      if (!previous.getPoint().getNodePtr()) {
+        llvm::outs() << ", was linked, exact position unknown\n";
+      } else if (previous.getPoint() == previous.getBlock()->end()) {
         llvm::outs() << ", was last in block\n";
       } else {
         llvm::outs() << ", previous = " << previous.getPoint()->getName()
@@ -349,9 +355,18 @@ struct DumpNotifications : public RewriterBase::Listener {
       }
     }
   }
+  void notifyBlockErased(Block *block) override {
+    llvm::outs() << "notifyBlockErased\n";
+  }
   void notifyOperationErased(Operation *op) override {
     llvm::outs() << "notifyOperationErased: " << op->getName() << "\n";
   }
+  void notifyOperationModified(Operation *op) override {
+    llvm::outs() << "notifyOperationModified: " << op->getName() << "\n";
+  }
+  void notifyOperationReplaced(Operation *op, ValueRange values) override {
+    llvm::outs() << "notifyOperationReplaced: " << op->getName() << "\n";
+  }
 };
 
 struct TestStrictPatternDriver
@@ -1153,6 +1168,8 @@ struct TestLegalizePatternDriver
     if (mode == ConversionMode::Partial) {
       DenseSet<Operation *> unlegalizedOps;
       ConversionConfig config;
+      DumpNotifications dumpNotifications;
+      config.listener = &dumpNotifications;
       config.unlegalizedOps = &unlegalizedOps;
       if (failed(applyPartialConversion(getOperation(), target,
                                         std::move(patterns), config))) {
@@ -1171,8 +1188,11 @@ struct TestLegalizePatternDriver
         return (bool)op->getAttrOfType<UnitAttr>("test.dynamically_legal");
       });
 
+      ConversionConfig config;
+      DumpNotifications dumpNotifications;
+      config.listener = &dumpNotifications;
       if (failed(applyFullConversion(getOperation(), target,
-                                     std::move(patterns)))) {
+                                     std::move(patterns), config))) {
         getOperation()->emitRemark() << "applyFullConversion failed";
       }
       return;


        


More information about the Mlir-commits mailing list