[Mlir-commits] [mlir] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect conversion (PR #81240)

Matthias Springer llvmlistbot at llvm.org
Wed Feb 14 08:33:56 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/81240

>From 5e261def657fc6ac280ab9af8d1d57208efe9aad Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Wed, 14 Feb 2024 16:08:38 +0000
Subject: [PATCH] [mlir][Transforms] Support `moveOpBefore`/`After` in dialect
 conversion

Add a new rewrite action for "operation movements". This action can roll back `moveOpBefore` and `moveOpAfter`.

`RewriterBase::moveOpBefore` and `RewriterBase::moveOpAfter` are no longer virtual. (The dialect conversion can gather all required information for rollbacks from listener notifications.)

BEGIN_PUBLIC
No public commit message needed for presubmit.
END_PUBLIC
---
 mlir/include/mlir/IR/PatternMatch.h           |  6 +-
 .../mlir/Transforms/DialectConversion.h       |  9 +--
 .../Transforms/Utils/DialectConversion.cpp    | 74 +++++++++++++++----
 mlir/test/Transforms/test-legalizer.mlir      | 14 ++++
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   | 20 ++++-
 5 files changed, 95 insertions(+), 28 deletions(-)

diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 78dcfe7f6fc3d2..b8aeea0d23475b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -588,8 +588,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right before
   /// `iterator` in the specified block.
-  virtual void moveOpBefore(Operation *op, Block *block,
-                            Block::iterator iterator);
+  void moveOpBefore(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this operation from its current block and insert it right after
   /// `existingOp` which may be in the same or another block in the same
@@ -598,8 +597,7 @@ class RewriterBase : public OpBuilder {
 
   /// Unlink this operation from its current block and insert it right after
   /// `iterator` in the specified block.
-  virtual void moveOpAfter(Operation *op, Block *block,
-                           Block::iterator iterator);
+  void moveOpAfter(Operation *op, Block *block, Block::iterator iterator);
 
   /// Unlink this block and insert it right before `existingBlock`.
   void moveBlockBefore(Block *block, Block *anotherBlock);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 851d639ae68a77..15fa39bde104b9 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -744,8 +744,8 @@ class ConversionPatternRewriter final : public PatternRewriter {
 
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
-  /// and not nested regions. Updates to regions will still require notification
-  /// through other more specific hooks above.
+  /// and not nested regions. Updates to regions will still require
+  /// notification through other more specific hooks above.
   void startOpModification(Operation *op) override;
 
   /// PatternRewriter hook for updating the given operation in-place.
@@ -761,11 +761,6 @@ class ConversionPatternRewriter final : public PatternRewriter {
   // Hide unsupported pattern rewriter API.
   using OpBuilder::setListener;
 
-  void moveOpBefore(Operation *op, Block *block,
-                    Block::iterator iterator) override;
-  void moveOpAfter(Operation *op, Block *block,
-                   Block::iterator iterator) override;
-
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index 9875f8668b65a8..84597fb7986b07 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -760,7 +760,8 @@ class IRRewrite {
     InlineBlock,
     MoveBlock,
     SplitBlock,
-    BlockTypeConversion
+    BlockTypeConversion,
+    MoveOperation
   };
 
   virtual ~IRRewrite() = default;
@@ -982,6 +983,54 @@ class BlockTypeConversionRewrite : public BlockRewrite {
   // `ArgConverter::applyRewrites`. This should be done in the "commit" method.
   void rollback() override;
 };
+
+/// An operation rewrite.
+class OperationRewrite : public IRRewrite {
+public:
+  /// Return the operation that this rewrite operates on.
+  Operation *getOperation() const { return op; }
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() >= Kind::MoveOperation &&
+           rewrite->getKind() <= Kind::MoveOperation;
+  }
+
+protected:
+  OperationRewrite(Kind kind, ConversionPatternRewriterImpl &rewriterImpl,
+                   Operation *op)
+      : IRRewrite(kind, rewriterImpl), op(op) {}
+
+  // The operation that this rewrite operates on.
+  Operation *op;
+};
+
+/// Moving of an operation. This rewrite is immediately reflected in the IR.
+class MoveOperationRewrite : public OperationRewrite {
+public:
+  MoveOperationRewrite(ConversionPatternRewriterImpl &rewriterImpl,
+                       Operation *op, Block *block, Operation *insertBeforeOp)
+      : OperationRewrite(Kind::MoveOperation, rewriterImpl, op), block(block),
+        insertBeforeOp(insertBeforeOp) {}
+
+  static bool classof(const IRRewrite *rewrite) {
+    return rewrite->getKind() == Kind::MoveOperation;
+  }
+
+  void rollback() override {
+    // Move the operation back to its original position.
+    Block::iterator before =
+        insertBeforeOp ? Block::iterator(insertBeforeOp) : block->end();
+    block->getOperations().splice(before, op->getBlock()->getOperations(), op);
+  }
+
+private:
+  // The block in which this operation was previously contained.
+  Block *block;
+
+  // The original successor of this operation before it was moved. "nullptr" if
+  // this operation was the last operation in its block.
+  Operation *insertBeforeOp;
+};
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -1478,12 +1527,19 @@ LogicalResult ConversionPatternRewriterImpl::convertNonEntryRegionTypes(
 
 void ConversionPatternRewriterImpl::notifyOperationInserted(
     Operation *op, OpBuilder::InsertPoint previous) {
-  assert(!previous.isSet() && "expected newly created op");
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
-  createdOps.push_back(op);
+  if (!previous.isSet()) {
+    // This is a newly created op.
+    createdOps.push_back(op);
+    return;
+  }
+  Operation *prevOp = previous.getPoint() == previous.getBlock()->end()
+                          ? nullptr
+                          : &*previous.getPoint();
+  appendRewrite<MoveOperationRewrite>(op, previous.getBlock(), prevOp);
 }
 
 void ConversionPatternRewriterImpl::notifyOpReplaced(Operation *op,
@@ -1722,18 +1778,6 @@ void ConversionPatternRewriter::cancelOpModification(Operation *op) {
   rootUpdates.erase(rootUpdates.begin() + updateIdx);
 }
 
-void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
-                                             Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
-void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
-                                            Block::iterator iterator) {
-  llvm_unreachable(
-      "moving single ops is not supported in a dialect conversion");
-}
-
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/test/Transforms/test-legalizer.mlir b/mlir/test/Transforms/test-legalizer.mlir
index d8cf6e4719cede..84fcc18ab7d370 100644
--- a/mlir/test/Transforms/test-legalizer.mlir
+++ b/mlir/test/Transforms/test-legalizer.mlir
@@ -320,3 +320,17 @@ module {
     return
   }
 }
+
+// -----
+
+// CHECK-LABEL: func @test_move_op_before_rollback()
+func.func @test_move_op_before_rollback() {
+  // CHECK: "test.one_region_op"()
+  // CHECK: "test.hoist_me"()
+  "test.one_region_op"() ({
+    // expected-remark @below{{'test.hoist_me' is not legalizable}}
+    %0 = "test.hoist_me"() : () -> (i32)
+    "test.valid"(%0) : (i32) -> ()
+  }) : () -> ()
+  "test.return"() : () -> ()
+}
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d7e5d6db50c1fb..1c02232b8adbb1 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -773,6 +773,22 @@ struct TestUndoBlockArgReplace : public ConversionPattern {
   }
 };
 
+/// This pattern hoists "test.hoist_me" ops out of their parent op and then
+/// fails conversion. This is to test the rollback logic.
+struct TestUndoMoveOpBefore : public ConversionPattern {
+  TestUndoMoveOpBefore(MLIRContext *ctx)
+      : ConversionPattern("test.hoist_me", /*benefit=*/1, ctx) {}
+
+  LogicalResult
+  matchAndRewrite(Operation *op, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const override {
+    rewriter.moveOpBefore(op, op->getParentOp());
+    // Replace with an illegal op to ensure the conversion fails.
+    rewriter.replaceOpWithNewOp<ILLegalOpF>(op, rewriter.getF32Type());
+    return success();
+  }
+};
+
 /// A rewrite pattern that tests the undo mechanism when erasing a block.
 struct TestUndoBlockErase : public ConversionPattern {
   TestUndoBlockErase(MLIRContext *ctx)
@@ -1069,7 +1085,7 @@ struct TestLegalizePatternDriver
              TestChangeProducerTypeF32ToInvalid, TestUpdateConsumerType,
              TestNonRootReplacement, TestBoundedRecursiveRewrite,
              TestNestedOpCreationUndoRewrite, TestReplaceEraseOp,
-             TestCreateUnregisteredOp>(&getContext());
+             TestCreateUnregisteredOp, TestUndoMoveOpBefore>(&getContext());
     patterns.add<TestDropOpSignatureConversion>(&getContext(), converter);
     mlir::populateAnyFunctionOpInterfaceTypeConversionPattern(patterns,
                                                               converter);
@@ -1079,7 +1095,7 @@ struct TestLegalizePatternDriver
     ConversionTarget target(getContext());
     target.addLegalOp<ModuleOp>();
     target.addLegalOp<LegalOpA, LegalOpB, LegalOpC, TestCastOp, TestValidOp,
-                      TerminatorOp>();
+                      TerminatorOp, OneRegionOp>();
     target
         .addIllegalOp<ILLegalOpF, TestRegionBuilderOp, TestOpWithRegionFold>();
     target.addDynamicallyLegalOp<TestReturnOp>([](TestReturnOp op) {



More information about the Mlir-commits mailing list