[llvm-branch-commits] [mlir] [mlir][Transforms] Add listener support to dialect conversion (PR #83425)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Thu Feb 29 05:07:13 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-core

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

This commit adds listener support to the dialect conversion. Similarly to the greedy pattern rewrite driver, an optional listener can be specified in the configuration object.

Listeners are notified only if the dialect conversion succeeds. In case of a failure, where some IR changes are first performed and then rolled back, no notifications are sent.

Due to the fact that some kinds of rewrite are reflected in the IR immediately and some in a delayed fashion, there are certain limitations when attaching a listener; these are documented in `ConversionConfig`. To summarize, users are always notified about all rewrites that happened, but the notifications are sent all at once at the very end, and not interleaved with the actual IR changes.

This change is in preparation for improvements to `transform.apply_conversion_patterns`, which currently invalidates all handles. In the future, it can use a listener to update handles accordingly, similar to `transform.apply_patterns`.


---

Patch is 24.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/83425.diff


4 Files Affected:

- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+26) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+128-46) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+70-1) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+24-4) 


``````````diff
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 84396529eb7c2e..98944b8a1ea648 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -1091,6 +1091,32 @@ struct ConversionConfig {
   /// IR during an analysis conversion and only pre-existing operations are
   /// added to the set.
   DenseSet<Operation *> *legalizableOps = nullptr;
+
+  /// An optional listener that is notified about all IR modifications in case
+  /// dialect conversion succeeds. If the dialect conversion fails and no IR
+  /// modifications are visible (i.e., they were all rolled back), no
+  /// notifications are sent.
+  ///
+  /// Note: Notifications are sent in a delayed fashion, when the dialect
+  /// conversion is guaranteed to succeed. At that point, some IR modifications
+  /// may already have been materialized. Consequently, operations/blocks that
+  /// are passed to listener callbacks should not be accessed. (Ops/blocks are
+  /// guaranteed to be valid pointers and accessing op names is allowed. But
+  /// there are no guarantees about the state of ops/blocks at the time that a
+  /// callback is triggered.)
+  ///
+  /// Example: Consider a dialect conversion in which a new op ("test.foo") is
+  /// created and inserted, and later moved to another block. (Moving ops also
+  /// triggers "notifyOperationInserted".)
+  ///
+  /// (1) notifyOperationInserted: "test.foo" (into block "b1")
+  /// (2) notifyOperationInserted: "test.foo" (moved to another block "b2")
+  ///
+  /// When querying "op->getBlock()" during the first "notifyOperationInserted",
+  /// "b2" would be returned because "moving an op" is a kind of rewrite that is
+  /// immediately performed by the dialect conversion (and rolled back upon
+  /// failure).
+  RewriterBase::Listener *listener = nullptr;
 };
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9f6468402686bd..0048347cae9314 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -203,14 +203,22 @@ class IRRewrite {
   /// Roll back the rewrite. Operations may be erased during rollback.
   virtual void rollback() = 0;
 
-  /// Commit the rewrite. Operations may be unlinked from their blocks during
-  /// the commit phase, but they must not be erased yet. This is because
-  /// internal dialect conversion state (such as `mapping`) may still be using
-  /// them. Operations must be erased during cleanup.
-  virtual void commit() {}
+  /// Commit the rewrite. At this point, it is certain that the dialect
+  /// conversion will succeed. All IR modifications, except for operation
+  /// erasure, must be performed through the given rewriter.
+  ///
+  /// Instead of erasing operations, they should merely be unlinked from their
+  /// blocks during the commit phase and finally be erased during the cleanup
+  /// phase. This is because internal dialect conversion state (such as
+  /// `mapping`) may still be using them.
+  ///
+  /// Any IR modification that was already performed before the commit phase
+  /// (e.g., insertion of an op) must be communicated to the listener that may
+  /// be attached to the given rewriter.
+  virtual void commit(RewriterBase &rewriter) {}
 
   /// Cleanup operations. Cleanup is called after commit.
-  virtual void cleanup() {}
+  virtual void cleanup(RewriterBase &rewriter) {}
 
   Kind getKind() const { return kind; }
 
@@ -220,12 +228,6 @@ class IRRewrite {
   IRRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl)
       : kind(kind), rewriterImpl(rewriterImpl) {}
 
-  /// Erase the given op (unless it was already erased).
-  void eraseOp(Operation *op);
-
-  /// Erase the given block (unless it was already erased).
-  void eraseBlock(Block *block);
-
   const ConversionConfig &getConfig() const;
 
   const Kind kind;
@@ -264,6 +266,12 @@ class CreateBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::CreateBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyBlockInserted(block, /*previous=*/{}, /*previousIt=*/{});
+  }
+
   void rollback() override {
     // Unlink all of the operations within this block, they will be deleted
     // separately.
@@ -309,10 +317,17 @@ class EraseBlockRewrite : public BlockRewrite {
     block = nullptr;
   }
 
-  void commit() override {
+  void commit(RewriterBase &rewriter) override {
     // Erase the block.
     assert(block && "expected block");
     assert(block->empty() && "expected empty block");
+
+    // Notify the listener that the block is about to be erased.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyBlockErased(block);
+
+    // Erase the block.
     block->dropAllDefinedValueUses();
     delete block;
     block = nullptr;
@@ -339,6 +354,16 @@ class MoveBlockRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::MoveBlock;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The block was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyBlockInserted(block, /*previous=*/region,
+                                    /*previousIt=*/{});
+    }
+  }
+
   void rollback() override {
     // Move the block back to its original position.
     Region::iterator before =
@@ -394,7 +419,7 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   LogicalResult
   materializeLiveConversions(function_ref<Operation *(Value)> findLiveUser);
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -423,7 +448,7 @@ class ReplaceBlockArgRewrite : public BlockRewrite {
     return rewrite->getKind() == Kind::ReplaceBlockArg;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
@@ -463,6 +488,17 @@ class MoveOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::MoveOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already moved. Just inform the listener.
+    if (auto *listener = rewriter.getListener()) {
+      // Note: `previousIt` cannot be passed because this is a delayed
+      // notification and iterators into past IR state cannot be represented.
+      listener->notifyOperationInserted(
+          op, /*previous=*/OpBuilder::InsertPoint(/*insertBlock=*/block,
+                                                  /*insertPt=*/{}));
+    }
+  }
+
   void rollback() override {
     // Move the operation back to its original position.
     Block::iterator before =
@@ -506,7 +542,12 @@ class ModifyOperationRewrite : public OperationRewrite {
            "rewrite was neither committed nor rolled back");
   }
 
-  void commit() override {
+  void commit(RewriterBase &rewriter) override {
+    // Notify the listener that the operation was modified in-place.
+    if (auto *listener =
+            dyn_cast_or_null<RewriterBase::Listener>(rewriter.getListener()))
+      listener->notifyOperationModified(op);
+
     if (propertiesStorage) {
       OpaqueProperties propCopy(propertiesStorage);
       // Note: The operation may have been erased in the mean time, so
@@ -557,11 +598,11 @@ class ReplaceOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::ReplaceOperation;
   }
 
-  void commit() override;
+  void commit(RewriterBase &rewriter) override;
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   const TypeConverter *getConverter() const { return converter; }
 
@@ -586,6 +627,12 @@ class CreateOperationRewrite : public OperationRewrite {
     return rewrite->getKind() == Kind::CreateOperation;
   }
 
+  void commit(RewriterBase &rewriter) override {
+    // The operation was already created and inserted. Just inform the listener.
+    if (auto *listener = rewriter.getListener())
+      listener->notifyOperationInserted(op, /*previous=*/{});
+  }
+
   void rollback() override;
 };
 
@@ -623,7 +670,7 @@ class UnresolvedMaterializationRewrite : public OperationRewrite {
 
   void rollback() override;
 
-  void cleanup() override;
+  void cleanup(RewriterBase &rewriter) override;
 
   /// Return the type converter of this materialization (which may be null).
   const TypeConverter *getConverter() const {
@@ -692,7 +739,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -853,6 +900,7 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
     }
 
     void notifyOperationErased(Operation *op) override { erased.insert(op); }
+
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
     /// Pointers to all erased operations and blocks.
@@ -863,8 +911,8 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   // State
   //===--------------------------------------------------------------------===//
 
-  /// This rewriter must be used for erasing ops/blocks.
-  SingleEraseRewriter eraseRewriter;
+  /// MLIR context.
+  MLIRContext *context;
 
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
@@ -908,19 +956,19 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 } // namespace detail
 } // namespace mlir
 
-void IRRewrite::eraseOp(Operation *op) {
-  rewriterImpl.eraseRewriter.eraseOp(op);
-}
-
-void IRRewrite::eraseBlock(Block *block) {
-  rewriterImpl.eraseRewriter.eraseBlock(block);
-}
-
 const ConversionConfig &IRRewrite::getConfig() const {
   return rewriterImpl.config;
 }
 
-void BlockTypeConversionRewrite::commit() {
+void BlockTypeConversionRewrite::commit(RewriterBase &rewriter) {
+  // Inform the listener about all IR modifications that have already taken
+  // place: References to the original block have been replaced with the new
+  // block.
+  if (auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
+          rewriter.getListener()))
+    for (Operation *op : block->getUsers())
+      listener->notifyOperationModified(op);
+
   // Process the remapping for each of the original arguments.
   for (auto [origArg, info] :
        llvm::zip_equal(origBlock->getArguments(), argInfo)) {
@@ -928,7 +976,7 @@ void BlockTypeConversionRewrite::commit() {
     if (!info) {
       if (Value newArg =
               rewriterImpl.mapping.lookupOrNull(origArg, origArg.getType()))
-        origArg.replaceAllUsesWith(newArg);
+        rewriter.replaceAllUsesWith(origArg, newArg);
       continue;
     }
 
@@ -938,8 +986,8 @@ void BlockTypeConversionRewrite::commit() {
 
     // If the argument is still used, replace it with the generated cast.
     if (!origArg.use_empty()) {
-      origArg.replaceAllUsesWith(
-          rewriterImpl.mapping.lookupOrDefault(castValue, origArg.getType()));
+      rewriter.replaceAllUsesWith(origArg, rewriterImpl.mapping.lookupOrDefault(
+                                               castValue, origArg.getType()));
     }
   }
 }
@@ -995,13 +1043,13 @@ LogicalResult BlockTypeConversionRewrite::materializeLiveConversions(
   return success();
 }
 
-void ReplaceBlockArgRewrite::commit() {
+void ReplaceBlockArgRewrite::commit(RewriterBase &rewriter) {
   Value repl = rewriterImpl.mapping.lookupOrNull(arg, arg.getType());
   if (!repl)
     return;
 
   if (isa<BlockArgument>(repl)) {
-    arg.replaceAllUsesWith(repl);
+    rewriter.replaceAllUsesWith(arg, repl);
     return;
   }
 
@@ -1010,7 +1058,7 @@ void ReplaceBlockArgRewrite::commit() {
   // replacement value.
   Operation *replOp = cast<OpResult>(repl).getOwner();
   Block *replBlock = replOp->getBlock();
-  arg.replaceUsesWithIf(repl, [&](OpOperand &operand) {
+  rewriter.replaceUsesWithIf(arg, repl, [&](OpOperand &operand) {
     Operation *user = operand.getOwner();
     return user->getBlock() != replBlock || replOp->isBeforeInBlock(user);
   });
@@ -1018,14 +1066,40 @@ void ReplaceBlockArgRewrite::commit() {
 
 void ReplaceBlockArgRewrite::rollback() { rewriterImpl.mapping.erase(arg); }
 
-void ReplaceOperationRewrite::commit() {
-  for (OpResult result : op->getResults())
-    if (Value newValue =
-            rewriterImpl.mapping.lookupOrNull(result, result.getType()))
-      result.replaceAllUsesWith(newValue);
+void ReplaceOperationRewrite::commit(RewriterBase &rewriter) {
+  auto *listener = dyn_cast_or_null<RewriterBase::ForwardingListener>(
+      rewriter.getListener());
+
+  // Compute replacement values.
+  SmallVector<Value> replacements =
+      llvm::map_to_vector(op->getResults(), [&](OpResult result) {
+        return rewriterImpl.mapping.lookupOrNull(result, result.getType());
+      });
+
+  // Notify the listener that the operation is about to be replaced.
+  if (listener)
+    listener->notifyOperationReplaced(op, replacements);
+
+  // Replace all uses with the new values.
+  for (auto [result, newValue] :
+       llvm::zip_equal(op->getResults(), replacements))
+    if (newValue)
+      rewriter.replaceAllUsesWith(result, newValue);
+
+  // The original op will be erased, so remove it from the set of unlegalized
+  // ops.
   if (getConfig().unlegalizedOps)
     getConfig().unlegalizedOps->erase(op);
+
+  // Notify the listener that the operation (and its nested operations) was
+  // erased.
+  if (listener) {
+    op->walk<WalkOrder::PostOrder>(
+        [&](Operation *op) { listener->notifyOperationErased(op); });
+  }
+
   // Do not erase the operation yet. It may still be referenced in `mapping`.
+  // Just unlink it for now and erase it during cleanup.
   op->getBlock()->getOperations().remove(op);
 }
 
@@ -1034,7 +1108,9 @@ void ReplaceOperationRewrite::rollback() {
     rewriterImpl.mapping.erase(result);
 }
 
-void ReplaceOperationRewrite::cleanup() { eraseOp(op); }
+void ReplaceOperationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void CreateOperationRewrite::rollback() {
   for (Region &region : op->getRegions()) {
@@ -1053,14 +1129,20 @@ void UnresolvedMaterializationRewrite::rollback() {
   op->erase();
 }
 
-void UnresolvedMaterializationRewrite::cleanup() { eraseOp(op); }
+void UnresolvedMaterializationRewrite::cleanup(RewriterBase &rewriter) {
+  rewriter.eraseOp(op);
+}
 
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Commit all rewrites.
+  IRRewriter rewriter(context, config.listener);
   for (auto &rewrite : rewrites)
-    rewrite->commit();
+    rewrite->commit(rewriter);
+
+  // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(context);
   for (auto &rewrite : rewrites)
-    rewrite->cleanup();
+    rewrite->cleanup(eraseRewriter);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 62d776cd7573ee..8af8102adf9754 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -1,5 +1,10 @@
 // RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-legalize-patterns -verify-diagnostics %s | FileCheck %s
 
+//      CHECK: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_a
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_a
+
 // CHECK-LABEL: verifyDirectPattern
 func.func @verifyDirectPattern() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -8,6 +13,16 @@ func.func @verifyDirectPattern() -> i32 {
   return %result : i32
 }
 
+// -----
+
+//      CHECK: notifyOperationInserted: test.illegal_op_e, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_c
+// CHECK-NEXT: notifyOperationModified: func.return
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_c
+// CHECK-NEXT: notifyOperationInserted: test.legal_op_a, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.illegal_op_e
+// CHECK-NEXT: notifyOperationErased: test.illegal_op_e
+
 // CHECK-LABEL: verifyLargerBenefit
 func.func @verifyLargerBenefit() -> i32 {
   // CHECK-NEXT:  "test.legal_op_a"() <{status = "Success"}
@@ -16,16 +31,24 @@ func.func @verifyLargerBenefit() -> i32 {
   return %result : i32
 }
 
+// -----
+
+// CHECK: notifyOperationModified: func.func
+// Note: No block insertion because this function is external and no block
+// signature conversion is performed.
+
 // CHECK-LABEL: func private @remap_input_1_to_0()
 func.func private @remap_input_1_to_0(i16)
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_1(%arg0: f64)
 func.func @remap_input_1_to_1(%arg0: i64) {
   // CHECK-NEXT: "test.valid"{{.*}} : (f64)
   "test.invalid"(%arg0) : (i64) -> ()
 }
 
-// CHECK-LABEL: func @remap_call_1_to_1(%arg0: f64)
+// CHECK: func @remap_call_1_to_1(%arg0: f64)
 func.func @remap_call_1_to_1(%arg0: i64) {
   // CHECK-NEXT: call @remap_input_1_to_1(%arg0) : (f64) -> ()
   call @remap_input_1_to_1(%arg0) : (i64) -> ()
@@ -33,12 +56,36 @@ func.func @remap_call_1_to_1(%arg0: i64) {
   return
 }
 
+// -----
+
+// Block signature conversion: new block is inserted.
+// CHECK:      notifyBlockInserted into func.func: was unlinked
+
+// Contents of the old block are moved to the new block.
+// CHECK-NEXT: notifyOperationInserted: test.return, was linked, exact position unknown
+
+// The new block arguments are used in "test.return".
+// CHECK-NEXT: notifyOperationModified: test.return
+
+// The old block is erased.
+// CHECK-NEXT: notifyBlockErased
+
+// The function op gets a new type attribute.
+// CHECK-NEXT: notifyOperationModified: func.func
+
+// "test.return" is replaced.
+// CHECK-NEXT: notifyOperationInserted: test.return, was unlinked
+// CHECK-NEXT: notifyOperationReplaced: test.return
+// CHECK-NEXT: notifyOperationErased: test.return
+
 // CHECK-LABEL: func @remap_input_1_to_N({{.*}}f16, {{.*}}f16)
 func.func @remap_input_1_to_N(%arg0: f32) -> f32 {
   // CHECK-NEXT: "test.return"{{.*}} : (f16, f16) -> ()
   "test.return"(%arg0) : (f32) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_1_to_N_remaining_use(%arg0: f16, %arg1: f16)
 func.func @remap_input_1_to_N_remaining_use(%arg0: f32) {
   // CHECK-NEXT: [[CAST:%.*]] = "test.cast"(%arg0, %arg1) : (f16, f16) -> f32
@@ -54,6 +101,8 @@ func.func @remap_materialize_1_to_1(%arg0: i42) {
   "test.return"(%arg0) : (i42) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_input_to_self
 func.func @remap_input_to_self(%arg0: index) {
   // CHECK-NOT: test.cast
@@ -68,6 +117,8 @@ func.func @remap_multi(%arg0: i64, %unused: i16, %arg1: i64) -> (i64, i64) {
  "test.invalid"(%arg0, %arg1) : (i64, i64) -> ()
 }
 
+// -----
+
 // CHECK-LABEL: func @no_remap_nested
 func.func @no_remap_nested() {
   // CHECK-NEXT: "foo.region"
@@ -82,6 +133,8 @@ func.func @no_remap_nested() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_moved_region_args
 func.func @remap_moved_region_args() {
   // CHECK-NEXT: return
@@ -96,6 +149,8 @@ func.func @remap_moved_region_args() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @remap_cloned_region_args
 func.func @remap_cloned_region_args() {
   // CHECK-NEXT: return
@@ -122,6 +177,8 @@ func.func @remap_drop_region() {
   return
 }
 
+// -----
+
 // CHECK-LABEL: func @dropped_input_in_use
 func.func @dropped_input_in_use(%arg: i16, %arg2: i64) {
   // CHECK-NEXT: "test.cast"{{.*}} : () -> i16
@@ -130,6 +187,8 @@ func.func @dropped_input_in_use(%arg: i16, %arg2: ...
[truncated]

``````````

</details>


https://github.com/llvm/llvm-project/pull/83425


More information about the llvm-branch-commits mailing list