[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Erase materialized constants instead of rollback (PR #136489)

Matthias Springer llvmlistbot at llvm.org
Mon Apr 21 23:55:46 PDT 2025


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/136489

>From 88b91b47063e8b8aea0c6a002d853e24c774ab8f Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Sun, 20 Apr 2025 15:37:56 +0200
Subject: [PATCH 1/2] [mlir][Transforms] Dialect conversion: Erase materialized
 constants instead of rollback

---
 mlir/include/mlir/IR/Builders.h               |  8 +++++--
 mlir/lib/IR/Builders.cpp                      |  9 ++++++--
 .../Transforms/Utils/DialectConversion.cpp    | 21 ++++++++-----------
 3 files changed, 22 insertions(+), 16 deletions(-)

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index cd8d3ee0af72b..8f13705fac96d 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -564,9 +564,13 @@ class OpBuilder : public Builder {
 
   /// Attempts to fold the given operation and places new results within
   /// `results`. Returns success if the operation was folded, failure otherwise.
-  /// If the fold was in-place, `results` will not be filled.
+  /// If the fold was in-place, `results` will not be filled. Optionally, newly
+  /// materialized constant operations can be returned to the caller.
+  ///
   /// Note: This function does not erase the operation on a successful fold.
-  LogicalResult tryFold(Operation *op, SmallVectorImpl<Value> &results);
+  LogicalResult
+  tryFold(Operation *op, SmallVectorImpl<Value> &results,
+          SmallVector<Operation *> *materializedConstants = nullptr);
 
   /// Creates a deep copy of the specified operation, remapping any operands
   /// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 16bd8201ad50a..9450ef7738fa0 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -465,8 +465,9 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
   return create(state);
 }
 
-LogicalResult OpBuilder::tryFold(Operation *op,
-                                 SmallVectorImpl<Value> &results) {
+LogicalResult
+OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
+                   SmallVector<Operation *> *materializedConstants) {
   assert(results.empty() && "expected empty results");
   ResultRange opResults = op->getResults();
 
@@ -528,6 +529,10 @@ LogicalResult OpBuilder::tryFold(Operation *op,
   for (Operation *cst : generatedConstants)
     insert(cst);
 
+  // Return materialized constant operations.
+  if (materializedConstants)
+    *materializedConstants = std::move(generatedConstants);
+
   return success();
 }
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 962207059c8aa..3059b35865bf2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2090,8 +2090,6 @@ LogicalResult
 OperationLegalizer::legalizeWithFold(Operation *op,
                                      ConversionPatternRewriter &rewriter) {
   auto &rewriterImpl = rewriter.getImpl();
-  RewriterState curState = rewriterImpl.getCurrentState();
-
   LLVM_DEBUG({
     rewriterImpl.logger.startLine() << "* Fold {\n";
     rewriterImpl.logger.indent();
@@ -2099,28 +2097,27 @@ OperationLegalizer::legalizeWithFold(Operation *op,
 
   // Try to fold the operation.
   SmallVector<Value, 2> replacementValues;
+  SmallVector<Operation *> newOps;
   rewriter.setInsertionPoint(op);
-  if (failed(rewriter.tryFold(op, replacementValues))) {
+  if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));
     return failure();
   }
+
   // An empty list of replacement values indicates that the fold was in-place.
   // As the operation changed, a new legalization needs to be attempted.
   if (replacementValues.empty())
     return legalize(op, rewriter);
 
   // Recursively legalize any new constant operations.
-  for (unsigned i = curState.numRewrites, e = rewriterImpl.rewrites.size();
-       i != e; ++i) {
-    auto *createOp =
-        dyn_cast<CreateOperationRewrite>(rewriterImpl.rewrites[i].get());
-    if (!createOp)
-      continue;
-    if (failed(legalize(createOp->getOperation(), rewriter))) {
+  for (Operation *newOp : newOps) {
+    if (failed(legalize(newOp, rewriter))) {
       LLVM_DEBUG(logFailure(rewriterImpl.logger,
                             "failed to legalize generated constant '{0}'",
-                            createOp->getOperation()->getName()));
-      rewriterImpl.resetState(curState);
+                            newOp->getName()));
+      // Legalization failed: erase all materialized constants.
+      for (Operation *op : newOps)
+        rewriter.eraseOp(op);
       return failure();
     }
   }

>From 8bf5ca017d06c7a3f4fe9eca0aa512969912adcb Mon Sep 17 00:00:00 2001
From: Matthias Springer <mspringer at nvidia.com>
Date: Tue, 22 Apr 2025 08:55:20 +0200
Subject: [PATCH 2/2] address review comments: take SmallVectorImpl in OpBuilder::tryFold

---
 mlir/include/mlir/IR/Builders.h                 | 2 +-
 mlir/lib/IR/Builders.cpp                        | 2 +-
 mlir/lib/Transforms/Utils/DialectConversion.cpp | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 8f13705fac96d..96dd14f142328 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -570,7 +570,7 @@ class OpBuilder : public Builder {
   /// Note: This function does not erase the operation on a successful fold.
   LogicalResult
   tryFold(Operation *op, SmallVectorImpl<Value> &results,
-          SmallVector<Operation *> *materializedConstants = nullptr);
+          SmallVectorImpl<Operation *> *materializedConstants = nullptr);
 
   /// Creates a deep copy of the specified operation, remapping any operands
   /// that use values outside of the operation using the map that is provided
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 9450ef7738fa0..89102115cdc40 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -467,7 +467,7 @@ Operation *OpBuilder::create(Location loc, StringAttr opName,
 
 LogicalResult
 OpBuilder::tryFold(Operation *op, SmallVectorImpl<Value> &results,
-                   SmallVector<Operation *> *materializedConstants) {
+                   SmallVectorImpl<Operation *> *materializedConstants) {
   assert(results.empty() && "expected empty results");
   ResultRange opResults = op->getResults();
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 3059b35865bf2..4d250329c6f45 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -2097,7 +2097,7 @@ OperationLegalizer::legalizeWithFold(Operation *op,
 
   // Try to fold the operation.
   SmallVector<Value, 2> replacementValues;
-  SmallVector<Operation *> newOps;
+  SmallVector<Operation *, 2> newOps;
   rewriter.setInsertionPoint(op);
   if (failed(rewriter.tryFold(op, replacementValues, &newOps))) {
     LLVM_DEBUG(logFailure(rewriterImpl.logger, "unable to fold"));



More information about the Mlir-commits mailing list