[Mlir-commits] [mlir] f27f1e8 - [mlir] DialectConversion: support block creation in ConversionPatternRewriter

Alex Zinenko llvmlistbot at llvm.org
Fri Apr 3 11:30:36 PDT 2020


Author: Alex Zinenko
Date: 2020-04-03T20:30:03+02:00
New Revision: f27f1e8c27b1d7cf624877e798999244a72adb41

URL: https://github.com/llvm/llvm-project/commit/f27f1e8c27b1d7cf624877e798999244a72adb41
DIFF: https://github.com/llvm/llvm-project/commit/f27f1e8c27b1d7cf624877e798999244a72adb41.diff

LOG: [mlir] DialectConversion: support block creation in ConversionPatternRewriter

PatternRewriter and derived classes provide a set of virtual methods to
manipulate blocks, which ConversionPatternRewriter overrides to keep track of
the manipulations and undo them in case the conversion fails. However, one can
currently create a block only by splitting another block into two. This not
only makes the API inconsistent (`splitBlock` is allowed in conversion
patterns, but `createBlock` is not), but it also makes it impossible for one to
create blocks with argument lists different from those of already existing
blocks since in-place block updates are not supported either. The lack of such
functionality prevents the dialect conversion infrastructure from being used more
extensively on region-containing ops, for example, for value-returning "if"
operations. At the same time, ConversionPatternRewriter already allows one to
undo block creation, since block creation is one of the primitive operations in
the already-supported region inlining.

Support block creation in conversion patterns by hooking `createBlock` on the
block action undo mechanism. This requires making `OpBuilder::createBlock`
virtual, similarly to operation insertion (`OpBuilder::insert`). This is a minimal change to the Builder
infrastructure that will later help support additional use cases such as block
signature changes. `createBlock` now additionally takes the types of the block
arguments and adds those arguments immediately, so as to avoid in-place argument
list manipulation that would be illegal in conversion patterns.

Added: 
    

Modified: 
    mlir/include/mlir/IR/Builders.h
    mlir/include/mlir/Transforms/DialectConversion.h
    mlir/lib/IR/Builders.cpp
    mlir/lib/Transforms/DialectConversion.cpp
    mlir/test/Transforms/test-legalizer.mlir
    mlir/test/lib/Dialect/Test/TestPatterns.cpp

Removed: 
    


################################################################################
diff  --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 1c6b16f22989..75f49e86d10a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -298,13 +298,15 @@ class OpBuilder : public Builder {
   /// Insert the given operation at the current insertion point and return it.
   virtual Operation *insert(Operation *op);
 
-  /// Add new block and set the insertion point to the end of it. The block is
-  /// inserted at the provided insertion point of 'parent'.
-  Block *createBlock(Region *parent, Region::iterator insertPt = {});
-
-  /// Add new block and set the insertion point to the end of it. The block is
-  /// placed before 'insertBefore'.
-  Block *createBlock(Block *insertBefore);
+  /// Add new block with 'argTypes' arguments and set the insertion point to the
+  /// end of it. The block is inserted at the provided insertion point of
+  /// 'parent'.
+  virtual Block *createBlock(Region *parent, Region::iterator insertPt = {},
+                             TypeRange argTypes = llvm::None);
+
+  /// Add new block with 'argTypes' arguments and set the insertion point to the
+  /// end of it. The block is placed before 'insertBefore'.
+  Block *createBlock(Block *insertBefore, TypeRange argTypes = llvm::None);
 
   /// Returns the current block of the builder.
   Block *getBlock() const { return block; }

diff  --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 776007347c5e..9ab3a715e0ab 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -344,6 +344,10 @@ class ConversionPatternRewriter final : public PatternRewriter {
   /// otherwise an assert will be issued.
   void eraseOp(Operation *op) override;
 
+  /// PatternRewriter hook for creating a new block with the given arguments.
+  Block *createBlock(Region *parent, Region::iterator insertPt = {},
+                     TypeRange argTypes = llvm::None) override;
+
   /// PatternRewriter hook for splitting a block into two parts.
   Block *splitBlock(Block *block, Block::iterator before) override;
 

diff  --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index 23536651f974..c8d5ea6b6ca9 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -339,24 +339,28 @@ Operation *OpBuilder::insert(Operation *op) {
   return op;
 }
 
-/// Add new block and set the insertion point to the end of it. The block is
-/// inserted at the provided insertion point of 'parent'.
-Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt) {
+/// Add new block with 'argTypes' arguments and set the insertion point to the
+/// end of it. The block is inserted at the provided insertion point of
+/// 'parent'.
+Block *OpBuilder::createBlock(Region *parent, Region::iterator insertPt,
+                              TypeRange argTypes) {
   assert(parent && "expected valid parent region");
   if (insertPt == Region::iterator())
     insertPt = parent->end();
 
   Block *b = new Block();
+  b->addArguments(argTypes);
   parent->getBlocks().insert(insertPt, b);
   setInsertionPointToEnd(b);
   return b;
 }
 
-/// Add new block and set the insertion point to the end of it.  The block is
-/// placed before 'insertBefore'.
-Block *OpBuilder::createBlock(Block *insertBefore) {
+/// Add new block with 'argTypes' arguments and set the insertion point to the
+/// end of it.  The block is placed before 'insertBefore'.
+Block *OpBuilder::createBlock(Block *insertBefore, TypeRange argTypes) {
   assert(insertBefore && "expected valid insertion block");
-  return createBlock(insertBefore->getParent(), Region::iterator(insertBefore));
+  return createBlock(insertBefore->getParent(), Region::iterator(insertBefore),
+                     argTypes);
 }
 
 /// Create an operation given the fields represented as an OperationState.

diff  --git a/mlir/lib/Transforms/DialectConversion.cpp b/mlir/lib/Transforms/DialectConversion.cpp
index 19304b3fb73f..725f5f4bb16e 100644
--- a/mlir/lib/Transforms/DialectConversion.cpp
+++ b/mlir/lib/Transforms/DialectConversion.cpp
@@ -585,6 +585,9 @@ struct ConversionPatternRewriterImpl {
   /// PatternRewriter hook for replacing the results of an operation.
   void replaceOp(Operation *op, ValueRange newValues);
 
+  /// Notifies that a block was created.
+  void notifyCreatedBlock(Block *block);
+
   /// Notifies that a block was split.
   void notifySplitBlock(Block *block, Block *continuation);
 
@@ -804,6 +807,10 @@ void ConversionPatternRewriterImpl::replaceOp(Operation *op,
   markNestedOpsIgnored(op);
 }
 
+void ConversionPatternRewriterImpl::notifyCreatedBlock(Block *block) {
+  blockActions.push_back(BlockAction::getCreate(block));
+}
+
 void ConversionPatternRewriterImpl::notifySplitBlock(Block *block,
                                                      Block *continuation) {
   blockActions.push_back(BlockAction::getSplit(continuation, block));
@@ -910,6 +917,15 @@ Value ConversionPatternRewriter::getRemappedValue(Value key) {
   return impl->mapping.lookupOrDefault(key);
 }
 
+/// PatternRewriter hook for creating a new block with the given arguments.
+Block *ConversionPatternRewriter::createBlock(Region *parent,
+                                              Region::iterator insertPtr,
+                                              TypeRange argTypes) {
+  Block *block = PatternRewriter::createBlock(parent, insertPtr, argTypes);
+  impl->notifyCreatedBlock(block);
+  return block;
+}
+
 /// PatternRewriter hook for splitting a block into two parts.
 Block *ConversionPatternRewriter::splitBlock(Block *block,
                                              Block::iterator before) {

diff  --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index bd73cf30639a..3305e017d5b3 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -130,6 +130,19 @@ func @remove_foldable_op(%arg0 : i32) -> (i32) {
   return %0 : i32
 }
 
+// CHECK-LABEL: @create_block
+func @create_block() {
+  "test.container"() ({
+    // Check that we created a block with arguments.
+    // CHECK-NOT: test.create_block
+    // CHECK: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+    // CHECK: test.finish
+    "test.create_block"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}
+
 // -----
 
 func @fail_to_convert_illegal_op() -> i32 {
@@ -163,3 +176,17 @@ func @fail_to_convert_region() {
   }) : () -> ()
   return
 }
+
+// -----
+
+// CHECK-LABEL: @create_illegal_block
+func @create_illegal_block() {
+  "test.container"() ({
+    // Check that we can undo block creation, i.e. that the block was removed.
+    // CHECK: test.create_illegal_block
+    // CHECK-NOT: ^{{.*}}(%{{.*}}: i32, %{{.*}}: i32):
+    "test.create_illegal_block"() : () -> ()
+    "test.finish"() : () -> ()
+  }) : () -> ()
+  return
+}

diff  --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index 0b73f09c1943..23d650e15479 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -183,6 +183,41 @@ struct TestRegionRewriteUndo : public RewritePattern {
     return success();
   }
 };
+/// A simple pattern that creates a block at the end of the parent region of the
+/// matched operation.
+struct TestCreateBlock : public RewritePattern {
+  TestCreateBlock(MLIRContext *ctx)
+      : RewritePattern("test.create_block", /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    Region &region = *op->getParentRegion();
+    Type i32Type = rewriter.getIntegerType(32);
+    rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
+    rewriter.create<TerminatorOp>(op->getLoc());
+    rewriter.replaceOp(op, {});
+    return success();
+  }
+};
+
+/// A simple pattern that creates a block containing an invalid operaiton in
+/// order to trigger the block creation undo mechanism.
+struct TestCreateIllegalBlock : public RewritePattern {
+  TestCreateIllegalBlock(MLIRContext *ctx)
+      : RewritePattern("test.create_illegal_block", /*benefit=*/1, ctx) {}
+
+  LogicalResult matchAndRewrite(Operation *op,
+                                PatternRewriter &rewriter) const final {
+    Region &region = *op->getParentRegion();
+    Type i32Type = rewriter.getIntegerType(32);
+    rewriter.createBlock(&region, region.end(), {i32Type, i32Type});
+    // Create an illegal op to ensure the conversion fails.
+    rewriter.create<ILLegalOpF>(op->getLoc(), i32Type);
+    rewriter.create<TerminatorOp>(op->getLoc());
+    rewriter.replaceOp(op, {});
+    return success();
+  }
+};
 
 //===----------------------------------------------------------------------===//
 // Type-Conversion Rewrite Testing
@@ -373,12 +408,12 @@ struct TestLegalizePatternDriver
     TestTypeConverter converter;
     mlir::OwningRewritePatternList patterns;
     populateWithGenerated(&getContext(), &patterns);
-    patterns
-        .insert<TestRegionRewriteBlockMovement, TestRegionRewriteUndo,
-                TestPassthroughInvalidOp, TestSplitReturnType,
-                TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
-                TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
-                TestNonRootReplacement>(&getContext());
+    patterns.insert<
+        TestRegionRewriteBlockMovement, TestRegionRewriteUndo, TestCreateBlock,
+        TestCreateIllegalBlock, TestPassthroughInvalidOp, TestSplitReturnType,
+        TestChangeProducerTypeI32ToF32, TestChangeProducerTypeF32ToF64,
+        TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
+        TestNonRootReplacement>(&getContext());
     patterns.insert<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateFuncOpTypeConversionPattern(patterns, &getContext(),
                                               converter);
@@ -388,7 +423,8 @@ struct TestLegalizePatternDriver
     // Define the conversion target used for the test.
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp, ModuleTerminatorOp>();
-    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp>();
+    target.addLegalOp<LegalOpA, LegalOpB, TestCastOp, TestValidOp,
+                      TerminatorOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {


        


More information about the Mlir-commits mailing list