[Mlir-commits] [mlir] e888886 - [mlir][DialectConversion] Add support for mergeBlocks in ConversionPatternRewriter.

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Aug 3 10:06:37 PDT 2020


Author: MaheshRavishankar
Date: 2020-08-03T10:06:04-07:00
New Revision: e888886cc3daf2c2d6c20cad51cd5ec2ffc24789

URL: https://github.com/llvm/llvm-project/commit/e888886cc3daf2c2d6c20cad51cd5ec2ffc24789
DIFF: https://github.com/llvm/llvm-project/commit/e888886cc3daf2c2d6c20cad51cd5ec2ffc24789.diff

LOG: [mlir][DialectConversion] Add support for mergeBlocks in ConversionPatternRewriter.

Differential Revision: https://reviews.llvm.org/D84795

Added: 
    mlir/test/Transforms/test-merge-blocks.mlir

Modified: 
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/lib/Dialect/Test/TestOps.td
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 9778958a4588..713f0b73dfe0 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -602,7 +602,14 @@ struct OpReplacement {
 
 /// The kind of the block action performed during the rewrite.  Actions can be
 /// undone if the conversion fails.
-enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
+enum class BlockActionKind {
+  Create,
+  Erase,
+  Merge,
+  Move,
+  Split,
+  TypeConversion
+};
 
 /// Original position of the given block in its parent region.  We cannot use
 /// a region iterator because it could have been invalidated by other region
@@ -612,6 +619,15 @@ struct BlockPosition {
   Region::iterator::difference_type position;
 };
 
+/// Information needed to undo the merge actions.
+/// - the source block, and
+/// - the Operation that was the last operation in the dest block before the
+///   merge (could be null if the dest block was empty).
+struct MergeInfo {
+  Block *sourceBlock;
+  Operation *destBlockLastInst;
+};
+
 /// The storage class for an undoable block action (one of BlockActionKind),
 /// contains the information necessary to undo this action.
 struct BlockAction {
@@ -621,6 +637,11 @@ struct BlockAction {
   static BlockAction getErase(Block *block, BlockPosition originalPos) {
     return {BlockActionKind::Erase, block, {originalPos}};
   }
+  static BlockAction getMerge(Block *block, Block *sourceBlock) {
+    BlockAction action{BlockActionKind::Merge, block, {}};
+    action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
+    return action;
+  }
   static BlockAction getMove(Block *block, BlockPosition originalPos) {
     return {BlockActionKind::Move, block, {originalPos}};
   }
@@ -647,6 +668,9 @@ struct BlockAction {
     // In use if kind == BlockActionKind::Split and contains a pointer to the
     // block that was split into two parts.
     Block *originalBlock;
+    // In use if kind == BlockActionKind::Merge, and contains the information
+    // needed to undo the merge.
+    MergeInfo mergeInfo;
   };
 };
 } // end anonymous namespace
@@ -738,6 +762,9 @@ struct ConversionPatternRewriterImpl {
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
 
+  /// Notifies that `srcBlock` is being merged into the end of `block`.
+  void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
+
   /// Notifies that the blocks of a region are about to be moved.
   void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
                                         Region::iterator before);
@@ -966,6 +993,20 @@ void ConversionPatternRewriterImpl::undoBlockActions(
           action.block);
       break;
     }
+    // Split the destination block (owned by the action) right after what was
+    // its last operation before the merge, and move the trailing operations
+    // back to the front of the source block.
+    case BlockActionKind::Merge: {
+      Block *sourceBlock = action.mergeInfo.sourceBlock;
+      Block::iterator splitPoint =
+          (action.mergeInfo.destBlockLastInst
+               ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
+               : action.block->begin());
+      sourceBlock->getOperations().splice(sourceBlock->begin(),
+                                          action.block->getOperations(),
+                                          splitPoint, action.block->end());
+      break;
+    }
     // Move the block back to its original position.
     case BlockActionKind::Move: {
       Region *originalRegion = action.originalPosition.region;
@@ -1161,6 +1202,11 @@ void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
   blockActions.push_back(BlockAction::getSplit(continuation, block));
 }
 
+void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
+                                                            Block *srcBlock) {
+  blockActions.push_back(BlockAction::getMerge(block, srcBlock));
+}
+
 void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
     Region &region, Region &parent, Region::iterator before) {
   for (auto &pair : llvm::enumerate(region)) {
@@ -1283,9 +1329,16 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
 /// PatternRewriter hook for merging a block into another.
 void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
                                             ValueRange argValues) {
-  // TODO: This requires fixing the implementation of
-  // 'replaceUsesOfBlockArgument', which currently isn't undoable.
-  llvm_unreachable("block merging updates are currently not supported");
+  assert(llvm::all_of(source->getPredecessors(),
+                      [dest](Block *pred) { return pred == dest; }) &&
+         "expected 'source' to have no predecessors or only 'dest'");
+  assert(argValues.size() == source->getNumArguments() &&
+         "incorrect # of argument replacement values");
+  impl->notifyBlocksBeingMerged(dest, source);
+  for (auto it : llvm::zip(source->getArguments(), argValues))
+    replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+  dest->getOperations().splice(dest->end(), source->getOperations());
+  eraseBlock(source);
 }
 
 /// PatternRewriter hook for moving blocks out of a region.

diff  --git a/mlir/test/Transforms/test-merge-blocks.mlir b/mlir/test/Transforms/test-merge-blocks.mlir
new file mode 100644
index 000000000000..65dd50569416
--- /dev/null
+++ b/mlir/test/Transforms/test-merge-blocks.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-merge-blocks -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @merge_blocks
+func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
+  //      CHECK: "test.merge_blocks"() ( {
+  // CHECK-NEXT:   "test.return"
+  // CHECK-NEXT: })
+  // CHECK-NEXT: "test.return"
+  %0:2 = "test.merge_blocks"() ({
+  ^bb0:
+     "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
+  ^bb1(%arg3 : i32, %arg4 : i32):
+     "test.return"(%arg3, %arg4) : (i32, i32) -> ()
+  }) : () -> (i32, i32)
+  "test.return"(%0#0, %0#1) : (i32, i32) -> ()
+}
+
+// -----
+
+// The op in this function is rewritten to itself (and thus remains
+// illegal) by a pattern that merges the second block with the first
+// after adding an operation into it.  Check that we can undo the block
+// merge successfully.
+// CHECK-LABEL: @undo_blocks_merge
+func @undo_blocks_merge(%arg0: i32) {
+  "test.undo_blocks_merge"() ({
+    // expected-remark@-1 {{op 'test.undo_blocks_merge' is not legalizable}}
+    // CHECK: "unregistered.return"(%{{.*}})[^[[BB:.*]]]
+    "unregistered.return"(%arg0)[^bb1] : (i32) -> ()
+    // expected-remark@-1 {{op 'unregistered.return' is not legalizable}}
+  // CHECK: ^[[BB]]
+  ^bb1(%arg1 : i32):
+    // CHECK: "unregistered.return"
+    "unregistered.return"(%arg1) : (i32) -> ()
+    // expected-remark@-1 {{op 'unregistered.return' is not legalizable}}
+  }) : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @inline_regions()
+func @inline_regions() -> ()
+{
+  //      CHECK: test.SingleBlockImplicitTerminator
+  // CHECK-NEXT:   %[[T0:.*]] = "test.type_producer"
+  // CHECK-NEXT:   "test.type_consumer"(%[[T0]])
+  // CHECK-NEXT:   "test.finish"
+  "test.SingleBlockImplicitTerminator"() ({
+  ^bb0:
+    %0 = "test.type_producer"() : () -> i32
+    "test.SingleBlockImplicitTerminator"() ({
+    ^bb1:
+      "test.type_consumer"(%0) : (i32) -> ()
+      "test.finish"() : () -> ()
+    }) : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index af20034abd9b..0c26f8a719c0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1163,6 +1163,16 @@ def TestTypeConsumerOp : TEST_Op<"type_consumer">,
 def TestValidOp : TEST_Op<"valid", [Terminator]>,
   Arguments<(ins Variadic<AnyType>)>;
 
+def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
+  let summary = "merge_blocks operation";
+  let description = [{
+    Test op with multiple blocks that are merged with Dialect Conversion.
+  }];
+
+  let regions = (region AnyRegion:$body);
+  let results = (outs Variadic<AnyType>:$result);
+}
+
 //===----------------------------------------------------------------------===//
 // Test parser.
 //===----------------------------------------------------------------------===//

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 5bc947fc8c91..f6607a5f5524 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -834,6 +834,118 @@ struct TestTypeConversionDriver
 };
 } // end anonymous namespace
 
+namespace {
+/// A rewrite pattern that tests that blocks can be merged.
+struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
+  using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Block &firstBlock = op.body().front();
+    Operation *branchOp = firstBlock.getTerminator();
+    Block *secondBlock = &*(std::next(op.body().begin()));
+    auto succOperands = branchOp->getOperands();
+    SmallVector<Value, 2> replacements(succOperands);
+    rewriter.eraseOp(branchOp);
+    rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+/// A rewrite pattern to test the undo mechanism of blocks being merged.
+struct TestUndoBlocksMerge : public ConversionPattern {
+  TestUndoBlocksMerge(MLIRContext *ctx)
+      : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    Block &firstBlock = op->getRegion(0).front();
+    Operation *branchOp = firstBlock.getTerminator();
+    Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
+    rewriter.setInsertionPointToStart(secondBlock);
+    rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+    auto succOperands = branchOp->getOperands();
+    SmallVector<Value, 2> replacements(succOperands);
+    rewriter.eraseOp(branchOp);
+    rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
+    rewriter.updateRootInPlace(op, [] {});
+    return success();
+  }
+};
+
+/// A rewrite pattern that inlines the body of the op into its parent block,
+/// when the op is nested within another SingleBlockImplicitTerminatorOp.
+struct TestMergeSingleBlockOps
+    : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
+  using OpConversionPattern<
+      SingleBlockImplicitTerminatorOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final {
+    SingleBlockImplicitTerminatorOp parentOp =
+        op.getParentOfType<SingleBlockImplicitTerminatorOp>();
+    if (!parentOp)
+      return failure();
+    // Split/merge around `op` in the block that actually contains it, since
+    // the matching ancestor need not be its immediate parent.
+    Block &parentBlock = *op.getOperation()->getBlock();
+    Block &innerBlock = op.region().front();
+    TerminatorOp innerTerminator =
+        cast<TerminatorOp>(innerBlock.getTerminator());
+    Block *parentEpilogue =
+        rewriter.splitBlock(&parentBlock, Block::iterator(op));
+    rewriter.eraseOp(innerTerminator);
+    rewriter.mergeBlocks(&innerBlock, &parentBlock, {});
+    rewriter.eraseOp(op);
+    rewriter.mergeBlocks(parentEpilogue, &parentBlock, {});
+    return success();
+  }
+};
+
+struct TestMergeBlocksPatternDriver
+    : public PassWrapper<TestMergeBlocksPatternDriver,
+                         OperationPass<ModuleOp>> {
+  void runOnOperation() override {
+    mlir::OwningRewritePatternList patterns;
+    MLIRContext *context = &getContext();
+    patterns
+        .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
+            context);
+    ConversionTarget target(*context);
+    target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
+                      TestTypeConsumerOp, TestTypeProducerOp, TestReturnOp>();
+    target.addIllegalOp<ILLegalOpF>();
+
+    // Expect the op to have a single block after legalization.
+    target.addDynamicallyLegalOp<TestMergeBlocksOp>(
+        [&](TestMergeBlocksOp op) -> bool {
+          return llvm::hasSingleElement(op.body());
+        });
+
+    // Only allow `test.br` within test.merge_blocks op.
+    target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
+      return op.getParentOfType<TestMergeBlocksOp>();
+    });
+
+    // Expect that all nested test.SingleBlockImplicitTerminator ops are
+    // inlined.
+    target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
+        [&](SingleBlockImplicitTerminatorOp op) -> bool {
+          return !op.getParentOfType<SingleBlockImplicitTerminatorOp>();
+        });
+
+    DenseSet<Operation *> unlegalizedOps;
+    (void)applyPartialConversion(getOperation(), target, patterns,
+                                 &unlegalizedOps);
+    for (auto *op : unlegalizedOps)
+      op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
+  }
+};
+} // namespace
+
 //===----------------------------------------------------------------------===//
 // PassRegistration
 //===----------------------------------------------------------------------===//
@@ -866,5 +978,9 @@ void registerPatternsTestPass() {
   PassRegistration<TestTypeConversionDriver>(
       "test-legalize-type-conversion",
       "Test various type conversion functionalities in DialectConversion");
+
+  PassRegistration<TestMergeBlocksPatternDriver>(
+      "test-merge-blocks",
+      "Test Merging operation in ConversionPatternRewriter");
 }
 } // namespace mlir


        


More information about the Mlir-commits mailing list