[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (PR #136489)
Matthias Springer
llvmlistbot at llvm.org
Sun Apr 20 06:44:06 PDT 2025
https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/136489
When constant operations that are illegal (and cannot be legalized) are materialized during op folding in a dialect conversion, they must be erased again. This used to be implemented with the rollback mechanism. This commit switches the implementation to regular rewriter API usage: the materialized constants are simply erased with `eraseOp`.
This commit is in preparation for the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks.
This commit also adds a new optional parameter to `OpBuilder::tryFold` through which callers can obtain the newly materialized constant operations.
>From 63e288819da19a199f692743d7b12226fb020e39 Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 20 Apr 2025 15:37:56 +0200
Subject: [PATCH] [mlir][Transforms] Dialect conversion: Erase materialized
constants instead of rollback
---
mlir/include/mlir/IR/Builders.h | 8 +++++--
mlir/lib/IR/Builders.cpp | 9 ++++++--
.../Transforms/Utils/DialectConversion.cpp | 21 ++++++++-----------
3 files changed, 22 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// `results`. Returns success if the operation was folded, failure otherwise.
- /// If the fold was in-place, `results` will not be filled.
+ /// If the fold was in-place, `results` will not be filled. Optionally, newly
+ /// materialized constant operations can be returned to the caller.
+ ///
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+ LogicalResult
+ tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
for (Operation *cst : generatedConstants)
insert(cst);
+ // Return materialized constant operations.
+ if (materializedConstants)
+ *materializedConstants = std::move(generatedConstants);
+
return success();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a56fca25e1697..63225c6bbee7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
- RewriterState curState = rewriterImpl.getCurrentState();
-
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
+ SmallVector<Operation *> newOps;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
- i != e; ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!createOp)
- continue;
- if (failed(legalize(createOp->getOperation(), rewriter))) {
+ for (Operation *newOp : newOps) {
+ if (failed(legalize(newOp, rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- createOp->getOperation()->getName()));
- rewriterImpl.resetState(curState);
+ newOp->getName()));
+ // Legalization failed: erase all materialized constants.
+ for (Operation *op : newOps)
+ rewriter.eraseOp(op);
return failure();
}
}
More information about the Mlir-commits
mailing list