[llvm-branch-commits] [mlir] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion (PR #81240)

via llvm-branch-commits llvm-branch-commits at lists.llvm.org
Fri Feb 9 02:02:03 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.

`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)


---
Full diff: https://github.com/llvm/llvm-project/pull/81240.diff


5 Files Affected:

- (modified) mlir/include/mlir/IR/PatternMatch.h (+2-4) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (-5) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+59-15) 
- (modified) mlir/test/Transforms/test-legalizer.mlir (+14) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+18-2) 


``````````diff
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d..b8aeea0d23475 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right before
   /// `iterator` in the specified block.
-  virtual void moveOpBefore(Operation *op, Block *block,
-                            Block::iterator iterator);
+  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this operation from its current block and insert it right after
   /// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right after
   /// `iterator` in the specified block.
-  virtual void moveOpAfter(Operation *op, Block *block,
-                           Block::iterator iterator);
+  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this block and insert it right before `existingBlock`.
   void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index f061d761ecefb..c0c702a7d3482 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -738,11 +738,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
 
-  void moveOpBefore(Operation *op, Block *block,
-                    Block::iterator iterator) override;
-  void moveOpAfter(Operation *op, Block *block,
-                   Block::iterator iterator) override;
-
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 44c107c8733f3..ffdb069f6e9b8 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -757,7 +757,8 @@ class RewriteAction {
     InlineBlock,
     MoveBlock,
     SplitBlock,
-    BlockTypeConversion
+    BlockTypeConversion,
+    MoveOperation
   };
 
   virtual ~RewriteAction() = default;
@@ -970,6 +971,54 @@ class BlockTypeConversionAction : public BlockAction {
 
   void rollback() override;
 };
+
+/// An operation rewrite.
+class OperationAction : public RewriteAction {
+public:
+  /// Return the operation that this action operates on.
+  Operation *getOperation() const { return op; }
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() >= Kind::MoveOperation &&
+           action->getKind() <= Kind::MoveOperation;
+  }
+
+protected:
+  OperationAction(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+                  Operation *op)
+      : RewriteAction(kind, rewriterImpl), op(op) {}
+
+  // The operation that this action operates on.
+  Operation *op;
+};
+
+/// Rewrite action that represents the moving of an operation.
+class MoveOperationAction : public OperationAction {
+public:
+  MoveOperationAction(ConversionPatternRewriterImpl &rewriterImpl,
+                      Operation *op, Block *block, Operation *insertBeforeOp)
+      : OperationAction(Kind::MoveOperation, rewriterImpl, op), block(block),
+        insertBeforeOp(insertBeforeOp) {}
+
+  static bool classof(const RewriteAction *action) {
+    return action->getKind() == Kind::MoveOperation;
+  }
+
+  void rollback() override {
+    // Move the operation back to its original position.
+    Block::iterator before =
+        insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+    block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+  }
+
+private:
+  // The block in which this operation was previously contained.
+  Block *block;
+
+  // The original successor of this operation before it was moved. "nullptr" if
+  // this operation was the last operation in its block.
+  Operation *insertBeforeOp;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1468,12 +1517,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 
 void ConversionPatternRewriterImpl::notifyOperationInserted(
     Operation *op, OpBuilder::InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  createdOps.push_back(op);
+  if (!previous.isSet()) {
+    // This is a newly created op.
+    createdOps.push_back(op);
+    return;
+  }
+  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+                          ? nullptr
+                          : &*previous.getPoint();
+  appendRewriteAction<MoveOperationAction>(op, previous.getBlock(), prevOp);
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1712,18 +1768,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
-                                             Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
-                                            Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719ced..84fcc18ab7d37 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+  // CHECK: "test.one_region_op"()
+  // CHECK: "test.hoist_me"()
+  "test.one_region_op"() ({
+    // expected-remark @below{{'test.hoist_me' is not legalizable}}
+    %0 = "test.hoist_me"() : () -> (i32)
+    "test.valid"(%0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1f..1c02232b8adbb 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+/// This pattern hoists "test.hoist_me" ops out of their parent op and then
+/// fails conversion. This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+  TestUndoMoveOpBefore(MLIRContext *ctx)
+      : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.moveOpBefore(op, op->getParentOp());
+    // Replace with an illegal op to ensure the conversion fails.
+    rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+    return success();
+  }
+};
+
 /// A rewrite pattern that tests the undo mechanism when erasing a block.
 struct TestUndoBlockErase : public ConversionPattern {
   TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp>(&getContext());
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp>();
+                      TerminatorOp, OneRegionOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {

``````````

</details>


https://github.com/llvm/llvm-project/pull/81240


More information about the llvm-branch-commits mailing list