[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Keep `unresolvedMaterializations` up to date (PR #144254)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Sun Jun 15 01:16:24 PDT 2025


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

`unresolvedMaterializations` is a mapping from `UnrealizedConversionCastOp` to `UnresolvedMaterializationRewrite`. This mapping is needed to find the correct type converter for an unresolved materialization.

With this commit, `unresolvedMaterializations` is updated immediately when an op is being erased. This also cleans up the code base a bit: `SingleEraseRewriter` is now used only during the "cleanup" phase and is no longer needed as a field of `ConversionPatternRewriterImpl`.

This commit is in preparation for the One-Shot Dialect Conversion refactoring: `allowPatternRollback = false` will in the future trigger immediate materialization of all IR changes.


---
Full diff: https://github.com/llvm/llvm-project/pull/144254.diff


1 Files Affected:

- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-13) 


``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
 struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
                                          const ConversionConfig &config)
-      : context(ctx), eraseRewriter(ctx), config(config) {}
+      : context(ctx), config(config) {}
 
   //===--------------------------------------------------------------------===//
   // State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// no new IR is created between calls to `eraseOp`/`eraseBlock`.
   struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
   public:
-    SingleEraseRewriter(MLIRContext *context)
-        : RewriterBase(context, /*listener=*/this) {}
+    SingleEraseRewriter(
+        MLIRContext *context,
+        llvm::function_ref<void(Operation *)> opErasedCallback = nullptr)
+        : RewriterBase(context, /*listener=*/this),
+          opErasedCallback(opErasedCallback) {}
 
     /// Erase the given op (unless it was already erased).
     void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
 
     bool wasErased(void *ptr) const { return erased.contains(ptr); }
 
-    void notifyOperationErased(Operation *op) override { erased.insert(op); }
+    void notifyOperationErased(Operation *op) override {
+      erased.insert(op);
+      if (opErasedCallback)
+        opErasedCallback(op);
+    }
 
     void notifyBlockErased(Block *block) override { erased.insert(block); }
 
   private:
     /// Pointers to all erased operations and blocks.
     DenseSet<void *> erased;
+
+    /// A callback that is invoked when an operation is erased.
+    llvm::function_ref<void(Operation *)> opErasedCallback;
   };
 
   //===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// MLIR context.
   MLIRContext *context;
 
-  /// A rewriter that keeps track of ops/block that were already erased and
-  /// skips duplicate op/block erasures. This rewriter is used during the
-  /// "cleanup" phase.
-  SingleEraseRewriter eraseRewriter;
-
   // Mapping between replaced values that differ in type. This happens when
   // replacing a value with one of a different type.
   ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     rewrites[i]->commit(rewriter);
 
   // Clean up all rewrites.
+  SingleEraseRewriter eraseRewriter(
+      context, /*opErasedCallback=*/[&](Operation *op) {
+        if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+          unresolvedMaterializations.erase(castOp);
+      });
   for (auto &rewrite : rewrites)
     rewrite->cleanup(eraseRewriter);
 }
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
   SmallVector<UnrealizedConversionCastOp> allCastOps;
   const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
       &materializations = rewriterImpl.unresolvedMaterializations;
-  for (auto it : materializations) {
-    if (rewriterImpl.eraseRewriter.wasErased(it.first))
-      continue;
+  for (auto it : materializations)
     allCastOps.push_back(it.first);
-  }
 
   // Reconcile all UnrealizedConversionCastOps that were inserted by the
   // dialect conversion frameworks. (Not the one that were inserted by

``````````

</details>


https://github.com/llvm/llvm-project/pull/144254


More information about the Mlir-commits mailing list