[Mlir-commits] [mlir] e888886 - [mlir][DialectConversion] Add support for mergeBlocks in ConversionPatternRewriter.
llvmlistbot at llvm.org
llvmlistbot at llvm.org
Mon Aug 3 10:06:37 PDT 2020
Author: MaheshRavishankar
Date: 2020-08-03T10:06:04-07:00
New Revision: e888886cc3daf2c2d6c20cad51cd5ec2ffc24789
URL: https://github.com/llvm/llvm-project/commit/e888886cc3daf2c2d6c20cad51cd5ec2ffc24789
DIFF: https://github.com/llvm/llvm-project/commit/e888886cc3daf2c2d6c20cad51cd5ec2ffc24789.diff
LOG: [mlir][DialectConversion] Add support for mergeBlocks in ConversionPatternRewriter.
Differential Revision: https://reviews.llvm.org/D84795
Added:
mlir/test/Transforms/test-merge-blocks.mlir
Modified:
mlir/lib/Transforms/DialectConversion.cpp
mlir/test/lib/Dialect/Test/TestOps.td
mlir/test/lib/Dialect/Test/TestPatterns.cpp
Removed:
################################################################################
diff --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 9778958a4588..713f0b73dfe0 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -602,7 +602,14 @@ struct OpReplacement {
/// The kind of the block action performed during the rewrite. Actions can be
/// undone if the conversion fails.
-enum class BlockActionKind { Create, Erase, Move, Split, TypeConversion };
+enum class BlockActionKind {
+ Create,
+ Erase,
+ Merge,
+ Move,
+ Split,
+ TypeConversion
+};
/// Original position of the given block in its parent region. We cannot use
/// a region iterator because it could have been invalidated by other region
@@ -612,6 +619,15 @@ struct BlockPosition {
Region::iterator::difference_type position;
};
+/// Information needed to undo the merge actions.
+/// - the source block, and
+/// - the Operation that was the last operation in the dest block before the
+/// merge (could be null if the dest block was empty).
+struct MergeInfo {
+ Block *sourceBlock;
+ Operation *destBlockLastInst;
+};
+
/// The storage class for an undoable block action (one of BlockActionKind),
/// contains the information necessary to undo this action.
struct BlockAction {
@@ -621,6 +637,11 @@ struct BlockAction {
static BlockAction getErase(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Erase, block, {originalPos}};
}
+ static BlockAction getMerge(Block *block, Block *sourceBlock) {
+ BlockAction action{BlockActionKind::Merge, block, {}};
+ action.mergeInfo = {sourceBlock, block->empty() ? nullptr : &block->back()};
+ return action;
+ }
static BlockAction getMove(Block *block, BlockPosition originalPos) {
return {BlockActionKind::Move, block, {originalPos}};
}
@@ -647,6 +668,9 @@ struct BlockAction {
// In use if kind == BlockActionKind::Split and contains a pointer to the
// block that was split into two parts.
Block *originalBlock;
+ // In use if kind == BlockActionKind::Merge, and contains the information
+ // needed to undo the merge.
+ MergeInfo mergeInfo;
};
};
} // end anonymous namespace
@@ -738,6 +762,9 @@ struct ConversionPatternRewriterImpl {
/// Notifies that a block was split.
void notifySplitBlock(Block *block, Block *continuation);
+ /// Notifies that `block` is being merged with `srcBlock`.
+ void notifyBlocksBeingMerged(Block *block, Block *srcBlock);
+
/// Notifies that the blocks of a region are about to be moved.
void notifyRegionIsBeingInlinedBefore(Region &region, Region &parent,
Region::iterator before);
@@ -966,6 +993,20 @@ void ConversionPatternRewriterImpl::undoBlockActions(
action.block);
break;
}
+ // Split the block at the position which was originally the end of the
+ // destination block (owned by action), and put the instructions back into
+ // the block used before the merge.
+ case BlockActionKind::Merge: {
+ Block *sourceBlock = action.mergeInfo.sourceBlock;
+ Block::iterator splitPoint =
+ (action.mergeInfo.destBlockLastInst
+ ? ++Block::iterator(action.mergeInfo.destBlockLastInst)
+ : action.block->begin());
+ sourceBlock->getOperations().splice(sourceBlock->begin(),
+ action.block->getOperations(),
+ splitPoint, action.block->end());
+ break;
+ }
// Move the block back to its original position.
case BlockActionKind::Move: {
Region *originalRegion = action.originalPosition.region;
@@ -1161,6 +1202,11 @@ void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
blockActions.push_back(BlockAction::getSplit(continuation, block));
}
+void ConversionPatternRewriterImpl::notifyBlocksBeingMerged(Block *block,
+ Block *srcBlock) {
+ blockActions.push_back(BlockAction::getMerge(block, srcBlock));
+}
+
void ConversionPatternRewriterImpl::notifyRegionIsBeingInlinedBefore(
Region &region, Region &parent, Region::iterator before) {
for (auto &pair : llvm::enumerate(region)) {
@@ -1283,9 +1329,16 @@ Block *ConversionPatternRewriter::splitBlock(Block *block,
/// PatternRewriter hook for merging a block into another.
void ConversionPatternRewriter::mergeBlocks(Block *source, Block *dest,
ValueRange argValues) {
- // TODO: This requires fixing the implementation of
- // 'replaceUsesOfBlockArgument', which currently isn't undoable.
- llvm_unreachable("block merging updates are currently not supported");
+ impl->notifyBlocksBeingMerged(dest, source);
+ assert(llvm::all_of(source->getPredecessors(),
+ [dest](Block *pred) { return pred == dest; }) &&
+ "expected 'source' to have no predecessors or only 'dest'");
+ assert(argValues.size() == source->getNumArguments() &&
+ "incorrect # of argument replacement values");
+ for (auto it : llvm::zip(source->getArguments(), argValues))
+ replaceUsesOfBlockArgument(std::get<0>(it), std::get<1>(it));
+ dest->getOperations().splice(dest->end(), source->getOperations());
+ eraseBlock(source);
}
/// PatternRewriter hook for moving blocks out of a region.
diff --git a/mlir/test/Transforms/test-merge-blocks.mlir b/mlir/test/Transforms/test-merge-blocks.mlir
new file mode 100644
index 000000000000..65dd50569416
--- /dev/null
+++ b/mlir/test/Transforms/test-merge-blocks.mlir
@@ -0,0 +1,59 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -test-merge-blocks -verify-diagnostics %s | FileCheck %s
+
+// CHECK-LABEL: @merge_blocks
+func @merge_blocks(%arg0: i32, %arg1 : i32) -> () {
+ // CHECK: "test.merge_blocks"() ( {
+ // CHECK-NEXT: "test.return"
+ // CHECK-NEXT: })
+ // CHECK-NEXT: "test.return"
+ %0:2 = "test.merge_blocks"() ({
+ ^bb0:
+ "test.br"(%arg0, %arg1)[^bb1] : (i32, i32) -> ()
+ ^bb1(%arg3 : i32, %arg4 : i32):
+ "test.return"(%arg3, %arg4) : (i32, i32) -> ()
+ }) : () -> (i32, i32)
+ "test.return"(%0#0, %0#1) : (i32, i32) -> ()
+}
+
+// -----
+
+// The op in this function is rewritten to itself (and thus remains
+// illegal) by a pattern that merges the second block with the first
+// after adding an operation into it. Check that we can undo block
+// removal successfully.
+// CHECK-LABEL: @undo_blocks_merge
+func @undo_blocks_merge(%arg0: i32) {
+ "test.undo_blocks_merge"() ({
+ // expected-remark@-1 {{op 'test.undo_blocks_merge' is not legalizable}}
+ // CHECK: "unregistered.return"(%{{.*}})[^[[BB:.*]]]
+ "unregistered.return"(%arg0)[^bb1] : (i32) -> ()
+ // expected-remark@-1 {{op 'unregistered.return' is not legalizable}}
+ // CHECK: ^[[BB]]
+ ^bb1(%arg1 : i32):
+ // CHECK: "unregistered.return"
+ "unregistered.return"(%arg1) : (i32) -> ()
+ // expected-remark@-1 {{op 'unregistered.return' is not legalizable}}
+ }) : () -> ()
+}
+
+// -----
+
+// CHECK-LABEL: @inline_regions()
+func @inline_regions() -> ()
+{
+ // CHECK: test.SingleBlockImplicitTerminator
+ // CHECK-NEXT: %[[T0:.*]] = "test.type_producer"
+ // CHECK-NEXT: "test.type_consumer"(%[[T0]])
+ // CHECK-NEXT: "test.finish"
+ "test.SingleBlockImplicitTerminator"() ({
+ ^bb0:
+ %0 = "test.type_producer"() : () -> i32
+ "test.SingleBlockImplicitTerminator"() ({
+ ^bb1:
+ "test.type_consumer"(%0) : (i32) -> ()
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ "test.finish"() : () -> ()
+ }) : () -> ()
+ "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td
index af20034abd9b..0c26f8a719c0 100644
--- a/mlir/test/lib/Dialect/Test/TestOps.td
+++ b/mlir/test/lib/Dialect/Test/TestOps.td
@@ -1163,6 +1163,16 @@ def TestTypeConsumerOp : TEST_Op<"type_consumer">,
def TestValidOp : TEST_Op<"valid", [Terminator]>,
Arguments<(ins Variadic<AnyType>)>;
+def TestMergeBlocksOp : TEST_Op<"merge_blocks"> {
+ let summary = "merge_blocks operation";
+ let description = [{
+ Test op with multiple blocks that are merged with Dialect Conversion
+ }];
+
+ let regions = (region AnyRegion:$body);
+ let results = (outs Variadic<AnyType>:$result);
+}
+
//===----------------------------------------------------------------------===//
// Test parser.
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 5bc947fc8c91..f6607a5f5524 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -834,6 +834,118 @@ struct TestTypeConversionDriver
};
} // end anonymous namespace
+namespace {
+/// A rewrite pattern that tests that blocks can be merged.
+struct TestMergeBlock : public OpConversionPattern<TestMergeBlocksOp> {
+ using OpConversionPattern<TestMergeBlocksOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(TestMergeBlocksOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ Block &firstBlock = op.body().front();
+ Operation *branchOp = firstBlock.getTerminator();
+ Block *secondBlock = &*(std::next(op.body().begin()));
+ auto succOperands = branchOp->getOperands();
+ SmallVector<Value, 2> replacements(succOperands);
+ rewriter.eraseOp(branchOp);
+ rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+
+/// A rewrite pattern to test the undo mechanism of blocks being merged.
+struct TestUndoBlocksMerge : public ConversionPattern {
+ TestUndoBlocksMerge(MLIRContext *ctx)
+ : ConversionPattern("test.undo_blocks_merge", /*benefit=*/1, ctx) {}
+ LogicalResult
+ matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ Block &firstBlock = op->getRegion(0).front();
+ Operation *branchOp = firstBlock.getTerminator();
+ Block *secondBlock = &*(std::next(op->getRegion(0).begin()));
+ rewriter.setInsertionPointToStart(secondBlock);
+ rewriter.create<ILLegalOpF>(op->getLoc(), rewriter.getF32Type());
+ auto succOperands = branchOp->getOperands();
+ SmallVector<Value, 2> replacements(succOperands);
+ rewriter.eraseOp(branchOp);
+ rewriter.mergeBlocks(secondBlock, &firstBlock, replacements);
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+
+/// A rewrite pattern to inline the body of the op into its parent, when both
+/// ops can have a single block.
+struct TestMergeSingleBlockOps
+ : public OpConversionPattern<SingleBlockImplicitTerminatorOp> {
+ using OpConversionPattern<
+ SingleBlockImplicitTerminatorOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(SingleBlockImplicitTerminatorOp op, ArrayRef<Value> operands,
+ ConversionPatternRewriter &rewriter) const final {
+ SingleBlockImplicitTerminatorOp parentOp =
+ op.getParentOfType<SingleBlockImplicitTerminatorOp>();
+ if (!parentOp)
+ return failure();
+ Block &parentBlock = parentOp.region().front();
+ Block &innerBlock = op.region().front();
+ TerminatorOp innerTerminator =
+ cast<TerminatorOp>(innerBlock.getTerminator());
+ Block *parentPrologue =
+ rewriter.splitBlock(&parentBlock, Block::iterator(op));
+ rewriter.eraseOp(innerTerminator);
+ rewriter.mergeBlocks(&innerBlock, &parentBlock, {});
+ rewriter.eraseOp(op);
+ rewriter.mergeBlocks(parentPrologue, &parentBlock, {});
+ rewriter.updateRootInPlace(op, [] {});
+ return success();
+ }
+};
+
+struct TestMergeBlocksPatternDriver
+ : public PassWrapper<TestMergeBlocksPatternDriver,
+ OperationPass<ModuleOp>> {
+ void runOnOperation() override {
+ mlir::OwningRewritePatternList patterns;
+ MLIRContext *context = &getContext();
+ patterns
+ .insert<TestMergeBlock, TestUndoBlocksMerge, TestMergeSingleBlockOps>(
+ context);
+ ConversionTarget target(*context);
+ target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp, TerminatorOp,
+ TestBranchOp, TestTypeConsumerOp, TestTypeProducerOp,
+ TestReturnOp>();
+ target.addIllegalOp<ILLegalOpF>();
+
+ /// Expect the op to have a single block after legalization.
+ target.addDynamicallyLegalOp<TestMergeBlocksOp>(
+ [&](TestMergeBlocksOp op) -> bool {
+ return llvm::hasSingleElement(op.body());
+ });
+
+ /// Only allow `test.br` within test.merge_blocks op.
+ target.addDynamicallyLegalOp<TestBranchOp>([&](TestBranchOp op) -> bool {
+ return op.getParentOfType<TestMergeBlocksOp>();
+ });
+
+ /// Expect that all nested test.SingleBlockImplicitTerminator ops are
+ /// inlined.
+ target.addDynamicallyLegalOp<SingleBlockImplicitTerminatorOp>(
+ [&](SingleBlockImplicitTerminatorOp op) -> bool {
+ return !op.getParentOfType<SingleBlockImplicitTerminatorOp>();
+ });
+
+ DenseSet<Operation *> unlegalizedOps;
+ (void)applyPartialConversion(getOperation(), target, patterns,
+ &unlegalizedOps);
+ for (auto *op : unlegalizedOps)
+ op->emitRemark() << "op '" << op->getName() << "' is not legalizable";
+ }
+};
+} // namespace
+
//===----------------------------------------------------------------------===//
// PassRegistration
//===----------------------------------------------------------------------===//
@@ -866,5 +978,9 @@ void registerPatternsTestPass() {
PassRegistration<TestTypeConversionDriver>(
"test-legalize-type-conversion",
"Test various type conversion functionalities in DialectConversion");
+
+ PassRegistration<TestMergeBlocksPatternDriver>{
+ "test-merge-blocks",
+ "Test Merging operation in ConversionPatternRewriter"};
}
} // namespace mlir
More information about the Mlir-commits mailing list