[Mlir-commits] [mlir] b98f414 - [mlir][DialectConversion] Emit an error if an operation marked as erased has live users after conversion

River Riddle llvmlistbot at llvm.org
Tue Jul 14 13:06:28 PDT 2020


Author: River Riddle
Date: 2020-07-14T13:06:08-07:00
New Revision: b98f414a04e19202669a4273e620bc12b5054413

URL: https://github.com/llvm/llvm-project/commit/b98f414a04e19202669a4273e620bc12b5054413
DIFF: https://github.com/llvm/llvm-project/commit/b98f414a04e19202669a4273e620bc12b5054413.diff

LOG: [mlir][DialectConversion] Emit an error if an operation marked as erased has live users after conversion

Up until now, there has been an implicit agreement that when an operation is marked as
"erased" all uses of that operation's results are guaranteed to be removed during conversion. How this works in practice is that there is either an assert/crash/asan failure/etc. This revision adds support for properly detecting when an erased operation has dangling users, emits an error, and fails the conversion.

Differential Revision: https://reviews.llvm.org/D82830

Added: 
    mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir

Modified: 
    mlir/lib/Transforms/DialectConversion.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 9401121eed96..b9ed64f573f2 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -519,10 +519,9 @@ class OperationTransactionState {
 /// This class represents one requested operation replacement via 'replaceOp'.
 struct OpReplacement {
   OpReplacement() = default;
-  OpReplacement(Operation *op, ValueRange newValues)
-      : op(op), newValues(newValues.begin(), newValues.end()) {}
+  OpReplacement(ValueRange newValues)
+      : newValues(newValues.begin(), newValues.end()) {}
 
-  Operation *op;
   SmallVector<Value, 2> newValues;
 };
 
@@ -681,8 +680,8 @@ struct ConversionPatternRewriterImpl {
   /// Ordered vector of all of the newly created operations during conversion.
   std::vector<Operation *> createdOps;
 
-  /// Ordered vector of any requested operation replacements.
-  SmallVector<OpReplacement, 4> replacements;
+  /// Ordered map of requested operation replacements.
+  llvm::MapVector<Operation *, OpReplacement> replacements;
 
   /// Ordered vector of any requested block argument replacements.
   SmallVector<BlockArgument, 4> argReplacements;
@@ -690,18 +689,29 @@ struct ConversionPatternRewriterImpl {
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<BlockAction, 4> blockActions;
 
-  /// A set of operations that have been erased/replaced/etc that should no
-  /// longer be considered for legalization. This is not meant to be an
-  /// exhaustive list of all operations, but the minimal set that can be used to
-  /// detect if a given operation should be `ignored`. For example, we may add
-  /// the operations that define non-empty regions to the set, but not any of
-  /// the others. This simplifies the amount of memory needed as we can query if
-  /// the parent operation was ignored.
+  /// A set of operations that should no longer be considered for legalization,
+  /// but were not directly replaced/erased/etc. by a pattern. These are
+  /// generally child operations of other operations who were
+  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
+  /// operations, but the minimal set that can be used to detect if a given
+  /// operation should be `ignored`. For example, we may add the operations that
+  /// define non-empty regions to the set, but not any of the others. This
+  /// simplifies the amount of memory needed as we can query if the parent
+  /// operation was ignored.
   llvm::SetVector<Operation *> ignoredOps;
 
   /// A transaction state for each of operations that were updated in-place.
   SmallVector<OperationTransactionState, 4> rootUpdates;
 
+  /// A vector of indices into `replacements` of operations that were replaced
+  /// with values with different result types than the original operation, e.g.
+  /// 1->N conversion of some kind.
+  SmallVector<unsigned, 4> operationsWithChangedResults;
+
+  /// A default type converter, used when block conversions do not have one
+  /// explicitly provided.
+  TypeConverter defaultTypeConverter;
+
 #ifndef NDEBUG
   /// A set of operations that have pending updates. This tracking isn't
   /// strictly necessary, and is thus only active during debug builds for extra
@@ -711,10 +721,6 @@ struct ConversionPatternRewriterImpl {
   /// A logger used to emit diagnostics during the conversion process.
   llvm::ScopedPrinter logger{llvm::dbgs()};
 #endif
-
-  /// A default type converter, used when block conversions do not have one
-  /// explicitly provided.
-  TypeConverter defaultTypeConverter;
 };
 } // end namespace detail
 } // end namespace mlir
@@ -728,10 +734,13 @@ struct ConversionPatternRewriterImpl {
 /// does not need to collect nested ops recursively because it is expected to
 /// also be called for each nested op when it is about to be deleted.
 static void detachNestedAndErase(Operation *op) {
-  for (Region &region : op->getRegions())
-    for (Block &block : region.getBlocks())
+  for (Region &region : op->getRegions()) {
+    for (Block &block : region.getBlocks()) {
       while (!block.getOperations().empty())
         block.getOperations().remove(block.getOperations().begin());
+      block.dropAllDefinedValueUses();
+    }
+  }
   op->erase();
 }
 
@@ -750,16 +759,16 @@ void ConversionPatternRewriterImpl::discardRewrites() {
 void ConversionPatternRewriterImpl::applyRewrites() {
   // Apply all of the rewrites replacements requested during conversion.
   for (auto &repl : replacements) {
-    for (unsigned i = 0, e = repl.newValues.size(); i != e; ++i) {
-      if (auto newValue = repl.newValues[i])
-        repl.op->getResult(i).replaceAllUsesWith(
+    for (unsigned i = 0, e = repl.second.newValues.size(); i != e; ++i) {
+      if (auto newValue = repl.second.newValues[i])
+        repl.first->getResult(i).replaceAllUsesWith(
             mapping.lookupOrDefault(newValue));
     }
 
     // If this operation defines any regions, drop any pending argument
     // rewrites.
-    if (repl.op->getNumRegions())
-      argConverter.notifyOpRemoved(repl.op);
+    if (repl.first->getNumRegions())
+      argConverter.notifyOpRemoved(repl.first);
   }
 
   // Apply all of the requested argument replacements.
@@ -785,7 +794,7 @@ void ConversionPatternRewriterImpl::applyRewrites() {
   // allows processing nested operations before their parent region is
   // destroyed.
   for (auto &repl : llvm::reverse(replacements))
-    repl.op->erase();
+    repl.first->erase();
 
   argConverter.applyRewrites(mapping);
 
@@ -819,9 +828,10 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
 
   // Reset any replaced operations and undo any saved mappings.
   for (auto &repl : llvm::drop_begin(replacements, state.numReplacements))
-    for (auto result : repl.op->getResults())
+    for (auto result : repl.first->getResults())
       mapping.erase(result);
-  replacements.resize(state.numReplacements);
+  while (replacements.size() != state.numReplacements)
+    replacements.pop_back();
 
   // Pop all of the newly created operations.
   while (createdOps.size() != state.numCreatedOps) {
@@ -832,6 +842,11 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
   // Pop all of the recorded ignored operations that are no longer valid.
   while (ignoredOps.size() != state.numIgnoredOperations)
     ignoredOps.pop_back();
+
+  // Reset operations with changed results.
+  while (!operationsWithChangedResults.empty() &&
+         operationsWithChangedResults.back() >= state.numReplacements)
+    operationsWithChangedResults.pop_back();
 }
 
 void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
@@ -898,8 +913,8 @@ void ConversionPatternRewriterImpl::remapValues(
 }
 
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
-  // Check to see if this operation or its parent were ignored.
-  return ignoredOps.count(op) || ignoredOps.count(op->getParentOp());
+  // Check to see if this operation was replaced or its parent ignored.
+  return replacements.count(op) || ignoredOps.count(op->getParentOp());
 }
 
 void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
@@ -963,14 +978,25 @@ FailureOr<Block *> ConversionPatternRewriterImpl::convertRegionTypes(
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
+  assert(!replacements.count(op) && "operation was already replaced");
+
+  // Track if any of the results changed, e.g. erased and replaced with null.
+  bool resultChanged = false;
 
   // Create mappings for each of the new result values.
-  for (unsigned i = 0, e = newValues.size(); i < e; ++i)
-    if (auto repl = newValues[i])
-      mapping.map(op->getResult(i), repl);
+  Value newValue, result;
+  for (auto it : llvm::zip(newValues, op->getResults())) {
+    std::tie(newValue, result) = it;
+    if (!newValue)
+      resultChanged = true;
+    else
+      mapping.map(result, newValue);
+  }
+  if (resultChanged)
+    operationsWithChangedResults.push_back(replacements.size());
 
   // Record the requested operation replacement.
-  replacements.emplace_back(op, newValues);
+  replacements.insert(std::make_pair(op, OpReplacement(newValues)));
 
   // Mark this operation as recursively ignored so that we don't need to
   // convert any nested operations.
@@ -1511,20 +1537,12 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   assert(impl.pendingRootUpdates.empty() && "dangling root updates");
 #endif
 
-  // Check all of the replacements to ensure that the pattern actually replaced
-  // the root operation. We also mark any other replaced ops as 'dead' so that
-  // we don't try to legalize them later.
-  bool replacedRoot = false;
-  for (unsigned i = curState.numReplacements, e = impl.replacements.size();
-       i != e; ++i) {
-    Operation *replacedOp = impl.replacements[i].op;
-    if (replacedOp == op)
-      replacedRoot = true;
-    else
-      impl.ignoredOps.insert(replacedOp);
-  }
-
-  // Check that the root was either updated or replace.
+  // Check that the root was either replaced or updated in place.
+  auto replacedRoot = [&] {
+    return llvm::any_of(
+        llvm::drop_begin(impl.replacements, curState.numReplacements),
+        [op](auto &it) { return it.first == op; });
+  };
   auto updatedRootInPlace = [&] {
     return llvm::any_of(
         llvm::drop_begin(impl.rootUpdates, curState.numRootUpdates),
@@ -1532,7 +1550,7 @@ LogicalResult OperationLegalizer::legalizePatternResult(
   };
   (void)replacedRoot;
   (void)updatedRootInPlace;
-  assert((replacedRoot || updatedRootInPlace()) &&
+  assert((replacedRoot() || updatedRootInPlace()) &&
          "expected pattern to replace the root operation");
 
   // Legalize each of the actions registered during application.
@@ -1856,6 +1874,10 @@ struct OperationConverter {
   /// Converts an operation with the given rewriter.
   LogicalResult convert(ConversionPatternRewriter &rewriter, Operation *op);
 
+  /// This method is called after the conversion process to legalize any
+  /// remaining artifacts and complete the conversion.
+  LogicalResult finalize(ConversionPatternRewriter &rewriter);
+
   /// The legalizer to use when converting operations.
   OperationLegalizer opLegalizer;
 
@@ -1916,16 +1938,56 @@ LogicalResult OperationConverter::convertOperations(ArrayRef<Operation *> ops) {
 
   // Convert each operation and discard rewrites on failure.
   ConversionPatternRewriter rewriter(ops.front()->getContext());
+  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
   for (auto *op : toConvert)
     if (failed(convert(rewriter, op)))
-      return rewriter.getImpl().discardRewrites(), failure();
+      return rewriterImpl.discardRewrites(), failure();
 
-  // Otherwise, the body conversion succeeded. Apply rewrites if this is not an
-  // analysis conversion.
+  // Now that all of the operations have been converted, finalize the conversion
+  // process to ensure any lingering conversion artifacts are cleaned up and
+  // legalized.
+  if (failed(finalize(rewriter)))
+    return rewriterImpl.discardRewrites(), failure();
+
+  // After a successful conversion, apply rewrites if this is not an analysis
+  // conversion.
   if (mode == OpConversionMode::Analysis)
-    rewriter.getImpl().discardRewrites();
+    rewriterImpl.discardRewrites();
   else
-    rewriter.getImpl().applyRewrites();
+    rewriterImpl.applyRewrites();
+  return success();
+}
+
+LogicalResult
+OperationConverter::finalize(ConversionPatternRewriter &rewriter) {
+  ConversionPatternRewriterImpl &rewriterImpl = rewriter.getImpl();
+  auto isOpDead = [&](Operation *op) { return rewriterImpl.isOpIgnored(op); };
+
+  // Process the operations with changed results.
+  for (unsigned replIdx : rewriterImpl.operationsWithChangedResults) {
+    auto &repl = *(rewriterImpl.replacements.begin() + replIdx);
+    for (auto it : llvm::zip(repl.first->getResults(), repl.second.newValues)) {
+      Value result = std::get<0>(it), newValue = std::get<1>(it);
+
+      // If the operation result was replaced with null, all of the uses of this
+      // value should be replaced.
+      if (newValue)
+        continue;
+
+      auto liveUserIt = llvm::find_if_not(result.getUsers(), isOpDead);
+      if (liveUserIt != result.user_end()) {
+        InFlightDiagnostic diag = repl.first->emitError()
+                                  << "failed to legalize operation '"
+                                  << repl.first->getName()
+                                  << "' marked as erased";
+        diag.attachNote(liveUserIt->getLoc())
+            << "found live user of result #"
+            << result.cast<OpResult>().getResultNumber() << ": " << *liveUserIt;
+        return failure();
+      }
+    }
+  }
+
   return success();
 }
 

diff  --git a/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir
new file mode 100644
index 000000000000..34c46d1cfc86
--- /dev/null
+++ b/mlir/test/Transforms/test-legalize-erased-op-with-uses.mlir
@@ -0,0 +1,10 @@
+// RUN: mlir-opt %s -test-legalize-unknown-root-patterns -verify-diagnostics
+
+// Test that an error is emitted when an operation is marked as "erased", but
+// has users that live across the conversion.
+func @remove_all_ops(%arg0: i32) -> i32 {
+  // expected-error@below {{failed to legalize operation 'test.illegal_op_a' marked as erased}}
+  %0 = "test.illegal_op_a"() : () -> i32
+  // expected-note@below {{found live user of result #0: return %0 : i32}}
+  return %0 : i32
+}


        


More information about the Mlir-commits mailing list