[Mlir-commits] [mlir] 2d3bbb6 - [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (#136489)
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Tue Apr 22 00:12:03 PDT 2025
Author: Matthias Springer
Date: 2025-04-22T09:12:00+02:00
New Revision: 2d3bbb6aafbc74ef6fc51286f09def0f0e35fe14
URL: https://github.com/llvm/llvm-project/commit/2d3bbb6aafbc74ef6fc51286f09def0f0e35fe14
DIFF: https://github.com/llvm/llvm-project/commit/2d3bbb6aafbc74ef6fc51286f09def0f0e35fe14.diff
LOG: [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (#136489)
When illegal (and not legalizable) constant operations are materialized
during a dialect conversion as part of op folding, these operations must
be deleted again. This used to be implemented via the rollback
mechanism. This commit switches the implementation to regular rewriter
API usage: simply delete the materialized constants with `eraseOp`.
This commit is in preparation for the One-Shot Dialect Conversion
refactoring, which will disallow IR rollbacks.
This commit also adds a new optional parameter to `OpBuilder::tryFold`
to get hold of the materialized constant ops.
Added:
Modified:
mlir/include/mlir/IR/Builders.h
mlir/lib/IR/Builders.cpp
mlir/lib/Transforms/Utils/DialectConversion.cpp
Removed:
################################################################################
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..96dd14f142328 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// `results`. Returns success if the operation was folded, failure otherwise.
- /// If the fold was in-place, `results` will not be filled.
+ /// If the fold was in-place, `results` will not be filled. Optionally, newly
+ /// materialized constant operations can be returned to the caller.
+ ///
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+ LogicalResult
+ tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVectorImpl<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..89102115cdc40 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVectorImpl<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
for (Operation *cst : generatedConstants)
insert(cst);
+ // Return materialized constant operations.
+ if (materializedConstants)
+ *materializedConstants = std::move(generatedConstants);
+
return success();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 962207059c8aa..4d250329c6f45 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
- RewriterState curState = rewriterImpl.getCurrentState();
-
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
+ SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
- i != e; ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!createOp)
- continue;
- if (failed(legalize(createOp->getOperation(), rewriter))) {
+ for (Operation *newOp : newOps) {
+ if (failed(legalize(newOp, rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- createOp->getOperation()->getName()));
- rewriterImpl.resetState(curState);
+ newOp->getName()));
+ // Legalization failed: erase all materialized constants.
+ for (Operation *op : newOps)
+ rewriter.eraseOp(op);
return failure();
}
}
More information about the Mlir-commits
mailing list