[Mlir-commits] [mlir] df48026 - [mlir] DialectConversion: support erasing blocks

Alex Zinenko llvmlistbot at llvm.org
Wed May 20 07:15:02 PDT 2020


Author: Alex Zinenko
Date: 2020-05-20T16:12:05+02:00
New Revision: df48026b4c30d88cc1221883631ac8aa7c4c376b

URL: https://github.com/llvm/llvm-project/commit/df48026b4c30d88cc1221883631ac8aa7c4c376b
DIFF: https://github.com/llvm/llvm-project/commit/df48026b4c30d88cc1221883631ac8aa7c4c376b.diff

LOG: [mlir] DialectConversion: support erasing blocks

PatternRewriter supports erasing a Block from its parent region, but this
feature has not been implemented for ConversionPatternRewriter, which needs to
keep track of block actions and be able to undo them. Introduce undoable block
erasure in ConversionPatternRewriter: all ops contained in the block are marked
for erasure and the block is detached from its parent region. The detached
block is stored in the action description and is not actually deleted until
the rewrites are applied.

Differential Revision: https://reviews.llvm.org/D80135

Added: 
    

Modified: 
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 851b6817bdfc..4dfb9b1de84a 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -509,7 +509,7 @@ struct ConversionPatternRewriterImpl {
 
   /// The kind of the block action performed during the rewrite.  Actions can be
   /// undone if the conversion fails.
-  enum class BlockActionKind { Create, Move, Split, TypeConversion };
+  enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
 
   /// Original position of the given block in its parent region.  We cannot use
   /// a region iterator because it could have been invalidated by other region
@@ -525,6 +525,9 @@ struct ConversionPatternRewriterImpl {
     static BlockAction getCreate(Block *block) {
       return {BlockActionKind::Create, block, {}};
     }
+    static BlockAction getErase(Block *block, BlockPosition originalPos) {
+      return {BlockActionKind::Erase, block, {originalPos}};
+    }
     static BlockAction getMove(Block *block, BlockPosition originalPos) {
       return {BlockActionKind::Move, block, {originalPos}};
     }
@@ -544,9 +547,9 @@ struct ConversionPatternRewriterImpl {
     Block *block;
 
     union {
-      // In use if kind == BlockActionKind::Move and contains a pointer to the
-      // region that originally contained the block as well as the position of
-      // the block in that region.
+      // In use if kind == BlockActionKind::Move or BlockActionKind::Erase, and
+      // contains a pointer to the region that originally contained the block as
+      // well as the position of the block in that region.
       BlockPosition originalPosition;
       // In use if kind == BlockActionKind::Split and contains a pointer to the
       // block that was split into two parts.
@@ -564,6 +567,10 @@ struct ConversionPatternRewriterImpl {
   /// Reset the state of the rewriter to a previously saved point.
   void resetState(RewriterState state);
 
+  /// Erase any blocks that were unlinked from their regions and stored in block
+  /// actions.
+  void eraseDanglingBlocks();
+
   /// Undo the block actions (motions, splits) one by one in reverse order until
   /// "numActionsToKeep" actions remains.
   void undoBlockActions(unsigned numActionsToKeep = 0);
@@ -587,6 +594,9 @@ struct ConversionPatternRewriterImpl {
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ValueRange newValues);
 
+  /// Notifies that a block is about to be erased.
+  void notifyBlockIsBeingErased(Block *block);
+
   /// Notifies that a block was created.
   void notifyCreatedBlock(Block *block);
 
@@ -711,6 +721,14 @@ void ConversionPatternRewriterImpl::resetState(RewriterState state) {
     ignoredOps.pop_back();
 }
 
+void ConversionPatternRewriterImpl::eraseDanglingBlocks() {
+  for (auto &action : blockActions) {
+    if (action.kind != BlockActionKind::Erase)
+      continue;
+    delete action.block;
+  }
+}
+
 void ConversionPatternRewriterImpl::undoBlockActions(
     unsigned numActionsToKeep) {
   for (auto &action :
@@ -727,6 +745,14 @@ void ConversionPatternRewriterImpl::undoBlockActions(
       action.block->erase();
       break;
     }
+    // Put the block (owned by action) back into its original position.
+    case BlockActionKind::Erase: {
+      auto &blockList = action.originalPosition.region->getBlocks();
+      blockList.insert(
+          std::next(blockList.begin(), action.originalPosition.position),
+          action.block);
+      break;
+    }
     // Move the block back to its original position.
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
@@ -806,6 +832,9 @@ void ConversionPatternRewriterImpl::applyRewrites() {
     repl.op->erase();
 
   argConverter.applyRewrites(mapping);
+
+  // Now that the ops have been erased, also erase dangling blocks.
+  eraseDanglingBlocks();
 }
 
 LogicalResult
@@ -853,6 +882,12 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op,
   markNestedOpsIgnored(op);
 }
 
+void ConversionPatternRewriterImpl::notifyBlockIsBeingErased(Block *block) {
+  Region *region = block->getParent();
+  auto position = std::distance(region->begin(), Region::iterator(block));
+  blockActions.push_back(BlockAction::getErase(block, {region, position}));
+}
+
 void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
   blockActions.push_back(BlockAction::getCreate(block));
 }
@@ -942,7 +977,17 @@ void ConversionPatternRewriter::eraseOp(Operation *op) {
 }
 
 void ConversionPatternRewriter::eraseBlock(Block *block) {
-  llvm_unreachable("erasing blocks for dialect conversion not implemented");
+  impl->notifyBlockIsBeingErased(block);
+
+  // Mark all ops for erasure.
+  for (Operation &op : *block)
+    eraseOp(&op);
+
+  // Unlink the block from its parent region. The block is kept in the block
+  // action and will be actually destroyed when rewrites are applied. This
+  // allows us to keep the operations in the block live and undo the removal by
+  // re-inserting the block.
+  block->getParent()->getBlocks().remove(block);
 }
 
 /// Apply a signature conversion to the entry block of the given region.
@@ -1334,7 +1379,8 @@ OperationLegalizer::legalizePattern(Operation *op, RewritePattern *pattern,
        i != e; ++i) {
     auto &action = rewriterImpl.blockActions[i];
     if (action.kind ==
-        ConversionPatternRewriterImpl::BlockActionKind::TypeConversion)
+            ConversionPatternRewriterImpl::BlockActionKind::TypeConversion ||
+        action.kind == ConversionPatternRewriterImpl::BlockActionKind::Erase)
       continue;
 
     // Convert the block signature.

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index 5f1411ce92cf..98f350053c04 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -236,6 +236,27 @@ func @undo_block_arg_replace() {
 
 // -----
 
+// The op in this function is rewritten to itself (and thus remains illegal) by
+// a pattern that removes its second block after adding an operation into it.
+// Check that we can undo block removal successfully.
+// CHECK-LABEL: @undo_block_erase
+func @undo_block_erase() {
+  // CHECK: test.undo_block_erase
+  "test.undo_block_erase"() ({
+    // expected-remark@-1 {{not legalizable}}
+    // CHECK: "unregistered.return"()[^[[BB:.*]]]
+    "unregistered.return"()[^bb1] : () -> ()
+    // expected-remark@-1 {{not legalizable}}
+  // CHECK: ^[[BB]]
+  ^bb1:
+    // CHECK: unregistered.return
+    "unregistered.return"() : () -> ()
+    // expected-remark@-1 {{not legalizable}}
+  }) : () -> ()
+}
+
+// -----
+
 // The op in this function is attempted to be rewritten to another illegal op
 // with an attached region containing an invalid terminator. The terminator is
 // created before the parent op. The deletion should not crash when deleting

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 268d0eaf85e2..df3068cc6487 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -281,6 +281,23 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+/// A rewrite pattern that tests the undo mechanism when erasing a block.
+struct TestUndoBlockErase : public ConversionPattern {
+  TestUndoBlockErase(MLIRContext *ctx)
+      : ConversionPattern("test.undo_block_erase", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Block *secondBlock = &*std::next(op->getRegion(0).begin());
+    rewriter.setInsertionPointToStart(secondBlock);
+    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    rewriter.eraseBlock(secondBlock);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
 
@@ -504,14 +521,14 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    patterns.insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
-                    TestCreateBlock, TestCreateIllegalBlock,
-                    TestUndoBlockArgReplace, TestPassthroughInvalidOp,
-                    TestSplitReturnType, TestChangeProducerTypeI32ToF32,
-                    TestChangeProducerTypeF32ToF64,
-                    TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-                    TestNonRootReplacement, TestBoundedRecursiveRewrite,
-                    TestNestedOpCreationUndoRewrite>(&getContext());
+    patterns.insert<
+        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
+        TestCreateIllegalBlock, TestUndoBlockArgReplace, TestUndoBlockErase,
+        TestPassthroughInvalidOp, TestSplitReturnType,
+        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+        TestNonRootReplacement, TestBoundedRecursiveRewrite,
+        TestNestedOpCreationUndoRewrite>(&getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);


        


More information about the Mlir-commits mailing list