[Mlir-commits] [mlir] [mlir][Transforms][NFC] Dialect Conversion: Keep `unresolvedMaterializations` up to date (PR #144254)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Jun 15 01:16:24 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
`unresolvedMaterializations` is a mapping from `UnrealizedConversionCastOp` to `UnresolvedMaterializationRewrite`. This mapping is needed to find the correct type converter for an unresolved materialization.
With this commit, `unresolvedMaterializations` is updated immediately when an op is erased. This also cleans up the code base a bit: `SingleEraseRewriter` is now instantiated locally for the "cleanup" phase and is no longer needed as a field of `ConversionPatternRewriterImpl`.
This commit is in preparation for the One-Shot Dialect Conversion refactoring: `allowPatternRollback = false` will in the future trigger immediate materialization of all IR changes.
---
Full diff: https://github.com/llvm/llvm-project/pull/144254.diff
1 File Affected:
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+20-13)
``````````diff
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 7de26d7cfa84d..b5345fb1a2dcb 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -848,7 +848,7 @@ namespace detail {
struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
explicit ConversionPatternRewriterImpl(MLIRContext *ctx,
const ConversionConfig &config)
- : context(ctx), eraseRewriter(ctx), config(config) {}
+ : context(ctx), config(config) {}
//===--------------------------------------------------------------------===//
// State Management
@@ -981,8 +981,11 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// no new IR is created between calls to `eraseOp`/`eraseBlock`.
struct SingleEraseRewriter : public RewriterBase, RewriterBase::Listener {
public:
- SingleEraseRewriter(MLIRContext *context)
- : RewriterBase(context, /*listener=*/this) {}
+ SingleEraseRewriter(
+ MLIRContext *context,
+ std::function<void(Operation *)> opErasedCallback = nullptr)
+ : RewriterBase(context, /*listener=*/this),
+ opErasedCallback(std::move(opErasedCallback)) {}
/// Erase the given op (unless it was already erased).
void eraseOp(Operation *op) override {
@@ -1003,13 +1006,20 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
bool wasErased(void *ptr) const { return erased.contains(ptr); }
- void notifyOperationErased(Operation *op) override { erased.insert(op); }
+ void notifyOperationErased(Operation *op) override {
+ erased.insert(op);
+ if (opErasedCallback)
+ opErasedCallback(op);
+ }
void notifyBlockErased(Block *block) override { erased.insert(block); }
private:
/// Pointers to all erased operations and blocks.
DenseSet<void *> erased;
+
+ /// Invoked when an op is erased. Owning: callers may pass a temporary lambda.
+ std::function<void(Operation *)> opErasedCallback;
};
//===--------------------------------------------------------------------===//
@@ -1019,11 +1029,6 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
/// MLIR context.
MLIRContext *context;
- /// A rewriter that keeps track of ops/block that were already erased and
- /// skips duplicate op/block erasures. This rewriter is used during the
- /// "cleanup" phase.
- SingleEraseRewriter eraseRewriter;
-
// Mapping between replaced values that differ in type. This happens when
// replacing a value with one of a different type.
ConversionValueMapping mapping;
@@ -1195,6 +1200,11 @@ void ConversionPatternRewriterImpl::applyRewrites() {
rewrites[i]->commit(rewriter);
// Clean up all rewrites.
+ SingleEraseRewriter eraseRewriter(
+ context, /*opErasedCallback=*/[&](Operation *op) {
+ if (auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
+ unresolvedMaterializations.erase(castOp);
+ });
for (auto &rewrite : rewrites)
rewrite->cleanup(eraseRewriter);
}
@@ -2714,11 +2724,8 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
SmallVector<UnrealizedConversionCastOp> allCastOps;
const DenseMap<UnrealizedConversionCastOp, UnresolvedMaterializationRewrite *>
&materializations = rewriterImpl.unresolvedMaterializations;
- for (auto it : materializations) {
- if (rewriterImpl.eraseRewriter.wasErased(it.first))
- continue;
+ for (auto it : materializations)
allCastOps.push_back(it.first);
- }
// Reconcile all UnrealizedConversionCastOps that were inserted by the
// dialect conversion frameworks. (Not the one that were inserted by
``````````
</details>
https://github.com/llvm/llvm-project/pull/144254
More information about the Mlir-commits mailing list