[Mlir-commits] [mlir] [mlir][IR] Add rewriter API for moving operations (PR #78988)

Matthias Springer llvmlistbot at llvm.org
Mon Jan 22 06:57:06 PST 2024


https://github.com/matthias-springer created https://github.com/llvm/llvm-project/pull/78988

The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`.

After an operation has been moved, the `notifyOperationInserted` callback is triggered. This allows listeners such as the greedy pattern rewrite driver to react to IR changes.

This commit narrows the gap between the kinds of IR modifications that can be performed and the kinds of IR modifications that listeners can observe.


>From 8a2015b7b6c3bc85b601bc398ddd2013f0ebf16b Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 22 Jan 2024 14:55:57 +0000
Subject: [PATCH] [mlir][IR] Add rewriter API for moving operations

The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`.

After an operation has been moved, the `notifyOperationInserted` callback is triggered. This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.
---
 .../toy/Ch5/mlir/LowerToAffineLoops.cpp       |  4 ++--
 .../toy/Ch6/mlir/LowerToAffineLoops.cpp       |  4 ++--
 .../toy/Ch7/mlir/LowerToAffineLoops.cpp       |  4 ++--
 mlir/include/mlir/IR/Builders.h               | 14 ++++++++----
 mlir/include/mlir/IR/PatternMatch.h           | 22 +++++++++++++++++++
 .../mlir/Transforms/DialectConversion.h       |  6 +++++
 .../lib/Dialect/SCF/Transforms/ForToWhile.cpp |  2 +-
 .../BufferizableOpInterfaceImpl.cpp           |  2 +-
 mlir/lib/IR/PatternMatch.cpp                  | 22 +++++++++++++++++++
 .../Transforms/Utils/DialectConversion.cpp    | 12 ++++++++++
 .../Utils/LoopInvariantCodeMotionUtils.cpp    |  4 ++--
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |  4 ++--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  4 +---
 13 files changed, 85 insertions(+), 19 deletions(-)

diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53..948c8a045e341e 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53..948c8a045e341e 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53..948c8a045e341e 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 13fbc3fb928c39..7b9e40e245c713 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -285,12 +285,18 @@ class OpBuilder : public Builder {
 
     virtual ~Listener() = default;
 
-    /// Notification handler for when an operation is inserted into the builder.
-    /// `op` is the operation that was inserted.
+    /// Notify the listener that the specified operation was inserted.
+    ///
+    /// Note: Creating an (unlinked) op does not trigger this notification;
+    /// it is triggered only once the op is inserted into a block. It is also
+    /// triggered when an existing operation is moved to a different
+    /// location.
+    // TODO: If needed, the previous location of the operation could be passed
+    // as a parameter. This would also allow listeners to distinguish between
+    // "newly created op was inserted" and "existing op was moved".
     virtual void notifyOperationInserted(Operation *op) {}
 
-    /// Notification handler for when a block is created using the builder.
-    /// `block` is the block that was created.
+    /// Notify the listener that the specified block was inserted.
     virtual void notifyBlockCreated(Block *block) {}
 
   protected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 815340c9185093..db95f7243e178c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
 
     /// Notify the listener that the specified operation is about to be erased.
     /// At this point, the operation has zero uses.
+    ///
+    /// Note: This notification is not triggered when unlinking an operation.
     virtual void notifyOperationRemoved(Operation *op) {}
 
     /// Notify the listener that the pattern failed to match the given
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
   /// block into a new block, and return it.
   virtual Block *splitBlock(Block *block, Block::iterator before);
 
+  /// Unlink `op` from its current block and insert it right before
+  /// `existingOp`, which may be in the same or another block in the same
+  /// function.
+  void moveOpBefore(Operation *op, Operation *existingOp);
+
+  /// Unlink `op` from its current block and insert it right before
+  /// `iterator` in `block`.
+  virtual void moveOpBefore(Operation *op, Block *block,
+                            Block::iterator iterator);
+
+  /// Unlink `op` from its current block and insert it right after
+  /// `existingOp`, which may be in the same or another block in the same
+  /// function.
+  void moveOpAfter(Operation *op, Operation *existingOp);
+
+  /// Unlink `op` from its current block and insert it right after
+  /// `iterator` in `block`.
+  virtual void moveOpAfter(Operation *op, Block *block,
+                           Block::iterator iterator);
+
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
   /// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9568540789df3f..7dc07e5b05e61a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
   detail::ConversionPatternRewriterImpl &getImpl();
 
 private:
+  // Hide unsupported pattern rewriter API.
   using OpBuilder::getListener;
   using OpBuilder::setListener;
 
+  void moveOpBefore(Operation *op, Block *block,
+                    Block::iterator iterator) override;
+  void moveOpAfter(Operation *op, Block *block,
+                   Block::iterator iterator) override;
+
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index cda561b1d1054d..9f8189ae15e6de 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
 
     // Inline for-loop body operations into 'after' region.
     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
-      arg.moveBefore(afterBlock, afterBlock->end());
+      rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
 
     // Add incremented IV to yield operations
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cd57e7324b4dc..678b7c099fa369 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
     for (Operation *user : srcBuffer->getUsers()) {
       if (hasEffect<MemoryEffects::Free>(user)) {
         if (user->getBlock() == parallelCombiningParent->getBlock())
-          user->moveBefore(user->getBlock()->getTerminator());
+          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
         break;
       }
     }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index ba0516e0539b6c..2acc1629ddac0a 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -366,3 +366,25 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
 void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
+
+void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
+  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpBefore(Operation *op, Block *block,
+                                Block::iterator iterator) {
+  op->moveBefore(block, iterator);
+  if (listener)
+    listener->notifyOperationInserted(op);
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
+  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Block *block,
+                               Block::iterator iterator) {
+  op->moveAfter(block, iterator);
+  if (listener)
+    listener->notifyOperationInserted(op);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ef6a49455d1860..0187436f700fd7 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1651,6 +1651,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
   return impl->notifyMatchFailure(loc, reasonCallback);
 }
 
+void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
+                                             Block::iterator iterator) {
+  llvm_unreachable(
+      "moving single ops is not supported in a dialect conversion");
+}
+
+void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
+                                            Block::iterator iterator) {
+  llvm_unreachable(
+      "moving single ops is not supported in a dialect conversion");
+}
+
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 8f97fd3d9ddf84..66ce6067963f83 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
       iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
       OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
       OpResult newLoopResult = loopLike.getLoopResults()->back();
-      extractionOp->moveBefore(loopLike);
-      insertionOp->moveAfter(loopLike);
+      rewriter.moveOpBefore(extractionOp, loopLike);
+      rewriter.moveOpAfter(insertionOp, loopLike);
       rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
                                   insertionOp.getDestinationOperand().get());
       extractionOp.getSourceOperand().set(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index a8a808424b690f..8a92d840ad1302 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
     // True branch.
-    op->moveBefore(&ifOp.getThenRegion().front(),
-                   ifOp.getThenRegion().front().begin());
+    rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
+                          ifOp.getThenRegion().front().begin());
     rewriter.setInsertionPointAfter(op);
     if (op->getNumResults() > 0)
       rewriter.create<scf::YieldOp>(loc, op->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d1ac5e81e75a69..89b9d1ce78a52b 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
       return failure();
     if (!toBeHoisted->hasAttr("eligible"))
       return failure();
-    // Hoisting means removing an op from the enclosing op. I.e., the enclosing
-    // op is modified.
-    rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+    rewriter.moveOpBefore(toBeHoisted, op);
     return success();
   }
 };



More information about the Mlir-commits mailing list