[Mlir-commits] [mlir] 7b66b5d - [mlir][Transforms] Track erased ops separately (#83051)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Feb 26 11:49:57 PST 2024


Author: Matthias Springer
Date: 2024-02-26T20:49:52+01:00
New Revision: 7b66b5d6c2fadb1e9f8cfc8d1864d4109105001f

URL: https://github.com/llvm/llvm-project/commit/7b66b5d6c2fadb1e9f8cfc8d1864d4109105001f
DIFF: https://github.com/llvm/llvm-project/commit/7b66b5d6c2fadb1e9f8cfc8d1864d4109105001f.diff

LOG: [mlir][Transforms] Track erased ops separately (#83051)

#83023 fixed a performance regression related to "ignored" ops. That fix
broke some downstream projects that access ops after they were replaced
(an API violation). This change restores the behavior from before
#83023 (but without the performance regression), to give downstream
users more time to fix their code.

Added: 
    

Modified: 
    mlir/lib/Transforms/Utils/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 4165e0a52428f9..f967e8352bf4c8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -153,9 +153,9 @@ namespace {
 /// This is useful when saving and undoing a set of rewrites.
 struct RewriterState {
   RewriterState(unsigned numRewrites, unsigned numIgnoredOperations,
-                unsigned numErased)
+                unsigned numErased, unsigned numReplacedOps)
       : numRewrites(numRewrites), numIgnoredOperations(numIgnoredOperations),
-        numErased(numErased) {}
+        numErased(numErased), numReplacedOps(numReplacedOps) {}
 
   /// The current number of rewrites performed.
   unsigned numRewrites;
@@ -165,6 +165,9 @@ struct RewriterState {
 
   /// The current number of erased operations/blocks.
   unsigned numErased;
+
+  /// The current number of replaced ops that are scheduled for erasure.
+  unsigned numReplacedOps;
 };
 
 //===----------------------------------------------------------------------===//
@@ -954,6 +957,9 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// operation was ignored.
   SetVector<Operation *> ignoredOps;
 
+  /// A set of operations that were replaced and are scheduled for erasure.
+  SetVector<Operation *> replacedOps;
+
   /// The current type converter, or nullptr if no type converter is currently
   /// active.
   const TypeConverter *currentTypeConverter = nullptr;
@@ -1152,7 +1158,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
 
 RewriterState ConversionPatternRewriterImpl::getCurrentState() {
   return RewriterState(rewrites.size(), ignoredOps.size(),
-                       eraseRewriter.erased.size());
+                       eraseRewriter.erased.size(), replacedOps.size());
 }
 
 void ConversionPatternRewriterImpl::resetState(RewriterState state) {
@@ -1165,6 +1171,9 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
 
   while (eraseRewriter.erased.size() != state.numErased)
     eraseRewriter.erased.pop_back();
+
+  while (replacedOps.size() != state.numReplacedOps)
+    replacedOps.pop_back();
 }
 
 void ConversionPatternRewriterImpl::undoRewrites(unsigned numRewritesToKeep) {
@@ -1228,9 +1237,11 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
   return success();
 }
 
+// TODO: This function is a misnomer. It does not actually check if `op` is in
+// `ignoredOps`.
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
   // Check to see if this operation or the parent operation is ignored.
-  return ignoredOps.count(op->getParentOp()) || ignoredOps.count(op);
+  return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
 }
 
 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -1479,7 +1490,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
-  assert(!ignoredOps.contains(op) && "operation was already replaced");
+  assert(!replacedOps.contains(op) && "operation was already replaced");
 
   // Track if any of the results changed, e.g. erased and replaced with null.
   bool resultChanged = false;
@@ -1500,7 +1511,7 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
 
   // Mark this operation as recursively ignored so that we don't need to
   // convert any nested operations.
-  ignoredOps.insert(op);
+  replacedOps.insert(op);
   markNestedOpsIgnored(op);
 }
 


        


More information about the Mlir-commits mailing list