[Mlir-commits] [mlir] [mlir][Transforms] Track erased ops separately (PR #83051)

Matthias Springer llvmlistbot at llvm.org
Mon Feb 26 11:46:11 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/83051

#83023 fixed a performance regression related to "ignored" ops. That fix broke some downstream projects that access ops after they have been replaced (an API violation). This change restores the behavior from before #83023 (without reintroducing the performance regression) to give downstream users more time to fix their code.

From 238322164f0c973318ddfcb9d66bcc5b05e7546c Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 26 Feb 2024 19:41:55 +0000
Subject: [PATCH] [mlir][Transforms] Track erased ops separately

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 .../Transforms/Utils/DialectConversion.cpp    | 23 ++++++++++++++-----
 1 file changed, 17 insertions(+), 6 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4165e0a52428f9..f967e8352bf4c8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased)
+                unsigned numErased, unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased) {}
+        numErased(numErased), numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -165,6 +165,9 @@ struct RewriterState {
 
   /// The current number of erased operations/blocks.
   unsigned numErased;
+
+  /// The current number of replaced ops that are scheduled for erasure.
+  unsigned numReplacedOps;
 };
 
 //===----------------------------------------------------------------------===//
@@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
+  /// A set of operations that were replaced and are scheduled for erasure.
+  SetVector<Operation *> replacedOps;
+
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
@@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size());
+                       eraseRewriter.erased.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
 
   while (eraseRewriter.erased.size() != state.numErased)
     eraseRewriter.erased.pop_back();
+
+  while (replacedOps.size() != state.numReplacedOps)
+    replacedOps.pop_back();
 }
 
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
   return success();
 }
 
+// TODO: The name of this function is a misnomer: it does not actually check
+// whether `op` itself is in `ignoredOps`.
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
   // Check to see if this operation or the parent operation is ignored.
-  return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op);
+  return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
 }
 
 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
-  assert(!ignoredOps.contains(op) && "operation was already replaced");
+  assert(!replacedOps.contains(op) && "operation was already replaced");
 
   // Track if any of the results changed, e.g. erased and replaced with null.
   bool resultChanged = false;
@@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 
   // Mark this operation as recursively ignored so that we don't need to
   // convert any nested operations.
-  ignoredOps.insert(op);
+  replacedOps.insert(op);
   markNestedOpsIgnored(op);
 }
 



More information about the Mlir-commits mailing list