[Mlir-commits] [mlir] [mlir][Transforms] Dialect conversion: Assert when accessing erased ops (PR #83132)

Matthias Springer llvmlistbot at llvm.org
Wed Feb 28 01:06:50 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/83132

>From 239784dae21bd0ac127c8df829fe9d809ed39205 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 28 Feb 2024 09:04:27 +0000
Subject: [PATCH] [mlir][Transforms] Keep track of nested ignored/replaced ops

The dialect conversion maintains sets of "ignored" and "replaced" ops. This change simplifies the two sets so that all nested ops are included. (Previously this was not always the case: sometimes only the parent op was included.)

This change allows for more aggressive assertions that catch incorrect rewriter API usage, e.g., accessing ops/blocks/regions within an erased op.

A concrete example: I have seen conversion patterns in downstream projects where an op is replaced with a new op, and the region of the old op is afterwards inlined into the newly created op. This is invalid rewriter API usage: ops that were replaced/erased must not be accessed. Nested ops are considered "ignored", even if they are moved to a different region after the region's parent op was erased (which is itself illegal API usage). Instead, create the new op first, inline the regions of the old op into it, and only then replace the old op with the new op.

BEGIN_PUBLIC
No commit message needed for presubmit.
END_PUBLIC
---
 .../Transforms/Utils/DialectConversion.cpp    | 93 +++++++++++--------
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  1 -
 2 files changed, 55 insertions(+), 39 deletions(-)

diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index f967e8352bf4c8..5d399ce1eb9cf0 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -798,13 +798,12 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
                             PatternRewriter &rewriter, ValueRange values,
                             SmallVectorImpl<Value> &remapped);
 
-  /// Returns true if the given operation is ignored, and does not need to be
+  /// Return "true" if the given operation is ignored, and does not need to be
   /// converted.
   bool isOpIgnored(Operation *op) const;
 
-  /// Recursively marks the nested operations under 'op' as ignored. This
-  /// removes them from being considered for legalization.
-  void markNestedOpsIgnored(Operation *op);
+  /// Return "true" if the given operation was replaced or erased.
+  bool wasOpReplaced(Operation *op) const;
 
   //===--------------------------------------------------------------------===//
   // Type Conversion
@@ -946,18 +945,15 @@ struct ConversionPatternRewriterImpl : public RewriterBase::Listener {
   /// Ordered list of block operations (creations, splits, motions).
   SmallVector<std::unique_ptr<IRRewrite>> rewrites;
 
-  /// A set of operations that should no longer be considered for legalization,
-  /// but were not directly replace/erased/etc. by a pattern. These are
-  /// generally child operations of other operations who were
-  /// replaced/erased/etc. This is not meant to be an exhaustive list of all
-  /// operations, but the minimal set that can be used to detect if a given
-  /// operation should be `ignored`. For example, we may add the operations that
-  /// define non-empty regions to the set, but not any of the others. This
-  /// simplifies the amount of memory needed as we can query if the parent
-  /// operation was ignored.
+  /// A set of operations that should no longer be considered for legalization.
+  /// E.g., ops that are recursively legal. Ops that were replaced/erased are
+  /// tracked separately.
   SetVector<Operation *> ignoredOps;
 
-  // A set of operations that were erased.
+  /// A set of operations that were replaced/erased. Such ops are not erased
+  /// immediately but only when the dialect conversion succeeds. In the mean
+  /// time, they should no longer be considered for legalization and any attempt
+  /// to modify/access them is invalid rewriter API usage.
   SetVector<Operation *> replacedOps;
 
   /// The current type converter, or nullptr if no type converter is currently
@@ -1237,24 +1233,14 @@ LogicalResult ConversionPatternRewriterImpl::remapValues(
   return success();
 }
 
-// TODO: This function is a misnomer. It does not actually check if `op` is in
-// `ignoredOps`.
 bool ConversionPatternRewriterImpl::isOpIgnored(Operation *op) const {
-  // Check to see if this operation or the parent operation is ignored.
-  return ignoredOps.count(op->getParentOp()) || replacedOps.count(op);
+  // Check to see if this operation is ignored or was replaced.
+  return replacedOps.count(op) || ignoredOps.count(op);
 }
 
-void ConversionPatternRewriterImpl::markNestedOpsIgnored(Operation *op) {
-  // Walk this operation and collect nested operations that define non-empty
-  // regions. We mark such operations as 'ignored' so that we know we don't have
-  // to convert them, or their nested ops.
-  if (op->getNumRegions() == 0)
-    return;
-  op->walk([&](Operation *op) {
-    if (llvm::any_of(op->getRegions(),
-                     [](Region &region) { return !region.empty(); }))
-      ignoredOps.insert(op);
-  });
+bool ConversionPatternRewriterImpl::wasOpReplaced(Operation *op) const {
+  // Check to see if this operation was replaced.
+  return replacedOps.count(op);
 }
 
 //===----------------------------------------------------------------------===//
@@ -1476,6 +1462,9 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
+  assert(!wasOpReplaced(op->getParentOp()) &&
+         "attempting to insert into a block within a replaced/erased op");
+
   if (!previous.isSet()) {
     // This is a newly created op.
     appendRewrite<CreateOperationRewrite>(op);
@@ -1490,7 +1479,7 @@ void ConversionPatternRewriterImpl::notifyOperationInserted(
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
                                                      ValueRange newValues) {
   assert(newValues.size() == op->getNumResults());
-  assert(!replacedOps.contains(op) && "operation was already replaced");
+  assert(!ignoredOps.contains(op) && "operation was already replaced");
 
   // Track if any of the results changed, e.g. erased and replaced with null.
   bool resultChanged = false;
@@ -1509,10 +1498,8 @@ void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
   appendRewrite<ReplaceOperationRewrite>(op, currentTypeConverter,
                                          resultChanged);
 
-  // Mark this operation as recursively ignored so that we don't need to
-  // convert any nested operations.
-  replacedOps.insert(op);
-  markNestedOpsIgnored(op);
+  // Mark this operation and all nested ops as replaced.
+  op->walk([&](Operation *op) { replacedOps.insert(op); });
 }
 
 void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
@@ -1523,6 +1510,9 @@ void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
 
 void ConversionPatternRewriterImpl::notifyBlockInserted(
     Block *block, Region *previous, Region::iterator previousIt) {
+  assert(!wasOpReplaced(block->getParentOp()) &&
+         "attempting to insert into a region within a replaced/erased op");
+
   if (!previous) {
     // This is a newly created block.
     appendRewrite<CreateBlockRewrite>(block);
@@ -1604,6 +1594,9 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
+  assert(!impl->wasOpReplaced(block->getParentOp()) &&
+         "attempting to erase a block within a replaced/erased op");
+
   // Mark all ops for erasure.
   for (Operation &op : *block)
     eraseOp(&op);
@@ -1619,18 +1612,27 @@ void ConversionPatternRewriter::eraseBlock(Block *block) {
 Block *ConversionPatternRewriter::applySignatureConversion(
     Region *region, TypeConverter::SignatureConversion &conversion,
     const TypeConverter *converter) {
+  assert(!impl->wasOpReplaced(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within a "
+         "replaced/erased op");
   return impl->applySignatureConversion(region, conversion, converter);
 }
 
 FailureOr<Block *> ConversionPatternRewriter::convertRegionTypes(
     Region *region, const TypeConverter &converter,
     TypeConverter::SignatureConversion *entryConversion) {
+  assert(!impl->wasOpReplaced(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within a "
+         "replaced/erased op");
   return impl->convertRegionTypes(region, converter, entryConversion);
 }
 
 LogicalResult ConversionPatternRewriter::convertNonEntryRegionTypes(
     Region *region, const TypeConverter &converter,
     ArrayRef<TypeConverter::SignatureConversion> blockConversions) {
+  assert(!impl->wasOpReplaced(region->getParentOp()) &&
+         "attempting to apply a signature conversion to a block within a "
+         "replaced/erased op");
   return impl->convertNonEntryRegionTypes(region, converter, blockConversions);
 }
 
@@ -1665,6 +1667,8 @@ ConversionPatternRewriter::getRemappedValues(ValueRange keys,
 
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {
+  assert(!impl->wasOpReplaced(block->getParentOp()) &&
+         "attempting to split a block within a replaced/erased op");
   auto *continuation = block->splitBlock(before);
   impl->notifySplitBlock(block, continuation);
   return continuation;
@@ -1673,15 +1677,19 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
 void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
                                                   Block::iterator before,
                                                   ValueRange argValues) {
+#ifndef NDEBUG
   assert(argValues.size() == source->getNumArguments() &&
          "incorrect # of argument replacement values");
-#ifndef NDEBUG
+  assert(!impl->wasOpReplaced(source->getParentOp()) &&
+         "attempting to inline a block from a replaced/erased op");
+  assert(!impl->wasOpReplaced(dest->getParentOp()) &&
+         "attempting to inline a block into a replaced/erased op");
   auto opIgnored = [&](Operation *op) { return impl->isOpIgnored(op); };
-#endif // NDEBUG
   // The source block will be deleted, so it should not have any users (i.e.,
   // there should be no predecessors).
   assert(llvm::all_of(source->getUsers(), opIgnored) &&
          "expected 'source' to have no predecessors");
+#endif // NDEBUG
 
   impl->notifyBlockBeingInlined(dest, source, before);
   for (auto it : llvm::zip(source->getArguments(), argValues))
@@ -1691,6 +1699,8 @@ void ConversionPatternRewriter::inlineBlockBefore(Block *source, Block *dest,
 }
 
 void ConversionPatternRewriter::startOpModification(Operation *op) {
+  assert(!impl->wasOpReplaced(op) &&
+         "attempting to modify a replaced/erased op");
 #ifndef NDEBUG
   impl->pendingRootUpdates.insert(op);
 #endif
@@ -1698,6 +1708,8 @@ void ConversionPatternRewriter::startOpModification(Operation *op) {
 }
 
 void ConversionPatternRewriter::finalizeOpModification(Operation *op) {
+  assert(!impl->wasOpReplaced(op) &&
+         "attempting to modify a replaced/erased op");
   PatternRewriter::finalizeOpModification(op);
   // There is nothing to do here, we only need to track the operation at the
   // start of the update.
@@ -1912,8 +1924,13 @@ OperationLegalizer::legalize(Operation *op,
 
     // If this operation is recursively legal, mark its children as ignored so
     // that we don't consider them for legalization.
-    if (legalityInfo->isRecursivelyLegal)
-      rewriter.getImpl().markNestedOpsIgnored(op);
+    if (legalityInfo->isRecursivelyLegal) {
+      op->walk([&](Operation *nested) {
+        if (op != nested)
+          rewriter.getImpl().ignoredOps.insert(nested);
+      });
+    }
+
     return success();
   }
 
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index bde4255ee4b368..abc0e43c7b7f2d 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1768,7 +1768,6 @@ struct TestMergeSingleBlockOps
     rewriter.inlineBlockBefore(&innerBlock, op);
     rewriter.eraseOp(innerTerminator);
     rewriter.eraseOp(op);
-    rewriter.modifyOpInPlace(op, [] {});
     return success();
   }
 };



More information about the Mlir-commits mailing list