[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (PR #136489)
Matthias Springer
llvmlistbot at llvm.org
Mon Apr 21 23:55:46 PDT 2025
https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/136489
>From 88b91b47063e8b8aea0c6a002d853e24c774ab8f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 20 Apr 2025 15:37:56 +0200
Subject: [PATCH 1/2] [mlir][Transforms] Dialect conversion: Erase materialized
constants instead of rollback
---
mlir/include/mlir/IR/Builders.h | 8 +++++--
mlir/lib/IR/Builders.cpp | 9 ++++++--
.../Transforms/Utils/DialectConversion.cpp | 21 ++++++++-----------
3 files changed, 22 insertions(+), 16 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
/// Attempts to fold the given operation and places new results within
/// `results`. Returns success if the operation was folded, failure otherwise.
- /// If the fold was in-place, `results` will not be filled.
+ /// If the fold was in-place, `results` will not be filled. If non-null,
+ /// `materializedConstants` receives the newly materialized constant ops.
+ ///
/// Note: This function does not erase the operation on a successful fold.
- LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+ LogicalResult
+ tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
return create(state);
}
-LogicalResult OpBuilder::tryFold(Operation *op,
- SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+ SmallVector<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
for (Operation *cst : generatedConstants)
insert(cst);
+ // Return materialized constant operations.
+ if (materializedConstants)
+ *materializedConstants = std::move(generatedConstants);
+
return success();
}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 962207059c8aa..3059b35865bf2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
OperationLegalizer::legalizeWithFold(Operation *op,
ConversionPatternRewriter &rewriter) {
auto &rewriterImpl = rewriter.getImpl();
- RewriterState curState = rewriterImpl.getCurrentState();
-
LLVM_DEBUG({
rewriterImpl.logger.startLine() << "* Fold {\n";
rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
+ SmallVector<Operation *> newOps;
rewriter.setInsertionPoint(op);
- if (failed(rewriter.tryFold(op, replacementValues))) {
+ if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
return failure();
}
+
// An empty list of replacement values indicates that the fold was in-place.
// As the operation changed, a new legalization needs to be attempted.
if (replacementValues.empty())
return legalize(op, rewriter);
// Recursively legalize any new constant operations.
- for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
- i != e; ++i) {
- auto *createOp =
- dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
- if (!createOp)
- continue;
- if (failed(legalize(createOp->getOperation(), rewriter))) {
+ for (Operation *newOp : newOps) {
+ if (failed(legalize(newOp, rewriter))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger,
"failed to legalize generated constant '{0}'",
- createOp->getOperation()->getName()));
- rewriterImpl.resetState(curState);
+ newOp->getName()));
+ // Legalization failed: erase all materialized constants.
+ for (Operation *op : newOps)
+ rewriter.eraseOp(op);
return failure();
}
}
>From 8bf5ca017d06c7a3f4fe9eca0aa512969912adcb Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 22 Apr 2025 08:55:20 +0200
Subject: [PATCH 2/2] address comments
---
mlir/include/mlir/IR/Builders.h | 2 +-
mlir/lib/IR/Builders.cpp | 2 +-
mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 +-
3 files changed, 3 insertions(+), 3 deletions(-)
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 8f13705fac96d..96dd14f142328 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -570,7 +570,7 @@ class OpBuilder : public Builder {
/// Note: This function does not erase the operation on a successful fold.
LogicalResult
tryFold(Operation *op, SmallVectorImpl<Value> &results,
- SmallVector<Operation *> *materializedConstants = nullptr);
+ SmallVectorImpl<Operation *> *materializedConstants = nullptr);
/// Creates a deep copy of the specified operation, remapping any operands
/// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 9450ef7738fa0..89102115cdc40 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -467,7 +467,7 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
LogicalResult
OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
- SmallVector<Operation *> *materializedConstants) {
+ SmallVectorImpl<Operation *> *materializedConstants) {
assert(results.empty() && "expected empty results");
ResultRange opResults = op->getResults();
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3059b35865bf2..4d250329c6f45 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2097,7 +2097,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
// Try to fold the operation.
SmallVector<Value, 2> replacementValues;
- SmallVector<Operation *> newOps;
+ SmallVector<Operation *, 2> newOps;
rewriter.setInsertionPoint(op);
if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
More information about the Mlir-commits
mailing list