[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (PR #136489)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Sun Apr 20 06:44:40 PDT 2025
llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core
Author: Matthias Springer (matthias-springer)
<details>
<summary>Changes</summary>
When illegal (and not legalizable) constant operations are materialized during a dialect conversion as part of op folding, these operations must be deleted again. This used to be implemented via the rollback mechanism. This commit switches the implementation to regular rewriter API usage: simply delete the materialized constants with `eraseOp`.
This commit is in preparation for the One-Shot Dialect Conversion refactoring, which will disallow IR rollbacks.
This commit also adds a new optional parameter to `OpBuilder::tryFold` to get hold of the materialized constant ops.
---
Full diff: https://github.com/llvm/llvm-project/pull/136489.diff
3 Files Affected:
- (modified) mlir/include/mlir/IR/Builders.h (+6-2)
- (modified) mlir/lib/IR/Builders.cpp (+7-2)
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+9-12)
``````````diff
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// `results`. Returns success if the operation was folded, failure otherwise.
- /// If the fold was in-place, `results` will not be filled.
+ /// If the fold was in-place, `results` will not be filled. Optionally, newly
+ /// materialized constant operations can be returned to the caller.
+ ///
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+ LogicalResult
+ tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
for (Operation *cst : generatedConstants)
insert(cst);
+ // Return materialized constant operations.
+ if (materializedConstants)
+ *materializedConstants = std::move(generatedConstants);
+
return success();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index a56fca25e1697..63225c6bbee7c 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
- RewriterState curState = rewriterImpl.getCurrentState();
-
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
+ SmallVector<Operation *> newOps;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
- i != e; ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!createOp)
- continue;
- if (failed(legalize(createOp->getOperation(), rewriter))) {
+ for (Operation *newOp : newOps) {
+ if (failed(legalize(newOp, rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- createOp->getOperation()->getName()));
- rewriterImpl.resetState(curState);
+ newOp->getName()));
+ // Legalization failed: erase all materialized constants.
+ for (Operation *op : newOps)
+ rewriter.eraseOp(op);
return failure();
}
}
``````````
</details>
https://github.com/llvm/llvm-project/pull/136489
More information about the Mlir-commits
mailing list