[Mlir-commits] [mlir] [mlir][IR] Add rewriter API for moving operations (PR #78988)

llvmlistbot at llvm.org llvmlistbot at llvm.org
Mon Jan 22 06:57:37 PST 2024


llvmbot wrote:


<!--LLVM PR SUMMARY COMMENT-->

@llvm/pr-subscribers-mlir-tensor

Author: Matthias Springer (matthias-springer)

<details>
<summary>Changes</summary>

The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`.

After an operation has been moved, the `notifyOperationInserted` callback is triggered. This allows listeners such as the greedy pattern rewrite driver to react to IR changes.

This commit narrows the discrepancy between the kinds of IR modifications that can be performed and the kinds of IR modifications that can be listened to.


---
Full diff: https://github.com/llvm/llvm-project/pull/78988.diff


13 Files Affected:

- (modified) mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp (+2-2) 
- (modified) mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp (+2-2) 
- (modified) mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp (+2-2) 
- (modified) mlir/include/mlir/IR/Builders.h (+10-4) 
- (modified) mlir/include/mlir/IR/PatternMatch.h (+22) 
- (modified) mlir/include/mlir/Transforms/DialectConversion.h (+6) 
- (modified) mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp (+1-1) 
- (modified) mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp (+1-1) 
- (modified) mlir/lib/IR/PatternMatch.cpp (+22) 
- (modified) mlir/lib/Transforms/Utils/DialectConversion.cpp (+12) 
- (modified) mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp (+2-2) 
- (modified) mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp (+2-2) 
- (modified) mlir/test/lib/Dialect/Test/TestPatterns.cpp (+1-3) 


``````````diff
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 13fbc3fb928c399..7b9e40e245c713a 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -285,12 +285,18 @@ class OpBuilder : public Builder {
 
     virtual ~Listener() = default;
 
-    /// Notification handler for when an operation is inserted into the builder.
-    /// `op` is the operation that was inserted.
+    /// Notify the listener that the specified operation was inserted.
+    ///
+    /// Note: Creating an (unlinked) op does not trigger this notification.
+    /// Only when the op is inserted, this notification is triggered. This
+    /// notification is also triggered when moving an operation to a different
+    /// location.
+    // TODO: If needed, the previous location of the operation could be passed
+    // as a parameter. This would also allow listeners to distinguish between
+    // "newly created op was inserted" and "existing op was moved".
     virtual void notifyOperationInserted(Operation *op) {}
 
-    /// Notification handler for when a block is created using the builder.
-    /// `block` is the block that was created.
+    /// Notify the listener that the specified block was inserted.
     virtual void notifyBlockCreated(Block *block) {}
 
   protected:
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 815340c91850935..db95f7243e178c1 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
 
     /// Notify the listener that the specified operation is about to be erased.
     /// At this point, the operation has zero uses.
+    ///
+    /// Note: This notification is not triggered when unlinking an operation.
     virtual void notifyOperationRemoved(Operation *op) {}
 
     /// Notify the listener that the pattern failed to match the given
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
   /// block into a new block, and return it.
   virtual Block *splitBlock(Block *block, Block::iterator before);
 
+  /// Unlink this operation from its current block and insert it right before
+  /// `existingOp` which may be in the same or another block in the same
+  /// function.
+  void moveOpBefore(Operation *op, Operation *existingOp);
+
+  /// Unlink this operation from its current block and insert it right before
+  /// `iterator` in the specified block.
+  virtual void moveOpBefore(Operation *op, Block *block,
+                            Block::iterator iterator);
+
+  /// Unlink this operation from its current block and insert it right after
+  /// `existingOp` which may be in the same or another block in the same
+  /// function.
+  void moveOpAfter(Operation *op, Operation *existingOp);
+
+  /// Unlink this operation from its current block and insert it right after
+  /// `iterator` in the specified block.
+  virtual void moveOpAfter(Operation *op, Block *block,
+                           Block::iterator iterator);
+
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
   /// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9568540789df3f6..7dc07e5b05e61a1 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
   detail::ConversionPatternRewriterImpl &getImpl();
 
 private:
+  // Hide unsupported pattern rewriter API.
   using OpBuilder::getListener;
   using OpBuilder::setListener;
 
+  void moveOpBefore(Operation *op, Block *block,
+                    Block::iterator iterator) override;
+  void moveOpAfter(Operation *op, Block *block,
+                   Block::iterator iterator) override;
+
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index cda561b1d1054d9..9f8189ae15e6de2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
 
     // Inline for-loop body operations into 'after' region.
     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
-      arg.moveBefore(afterBlock, afterBlock->end());
+      rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
 
     // Add incremented IV to yield operations
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cd57e7324b4dc5..678b7c099fa3692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
     for (Operation *user : srcBuffer->getUsers()) {
       if (hasEffect<MemoryEffects::Free>(user)) {
         if (user->getBlock() == parallelCombiningParent->getBlock())
-          user->moveBefore(user->getBlock()->getTerminator());
+          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
         break;
       }
     }
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index ba0516e0539b6ca..2acc1629ddac0a6 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -366,3 +366,25 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
 void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
+
+void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
+  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpBefore(Operation *op, Block *block,
+                                Block::iterator iterator) {
+  op->moveBefore(block, iterator);
+  if (listener)
+    listener->notifyOperationInserted(op);
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
+  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Block *block,
+                               Block::iterator iterator) {
+  op->moveAfter(block, iterator);
+  if (listener)
+    listener->notifyOperationInserted(op);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ef6a49455d18605..0187436f700fd73 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1651,6 +1651,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
   return impl->notifyMatchFailure(loc, reasonCallback);
 }
 
+void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
+                                             Block::iterator iterator) {
+  llvm_unreachable(
+      "moving single ops is not supported in a dialect conversion");
+}
+
+void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
+                                            Block::iterator iterator) {
+  llvm_unreachable(
+      "moving single ops is not supported in a dialect conversion");
+}
+
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 8f97fd3d9ddf84e..66ce6067963f838 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
       iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
       OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
       OpResult newLoopResult = loopLike.getLoopResults()->back();
-      extractionOp->moveBefore(loopLike);
-      insertionOp->moveAfter(loopLike);
+      rewriter.moveOpBefore(extractionOp, loopLike);
+      rewriter.moveOpAfter(insertionOp, loopLike);
       rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
                                   insertionOp.getDestinationOperand().get());
       extractionOp.getSourceOperand().set(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index a8a808424b690f6..8a92d840ad13026 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
     // True branch.
-    op->moveBefore(&ifOp.getThenRegion().front(),
-                   ifOp.getThenRegion().front().begin());
+    rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
+                          ifOp.getThenRegion().front().begin());
     rewriter.setInsertionPointAfter(op);
     if (op->getNumResults() > 0)
       rewriter.create<scf::YieldOp>(loc, op->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d1ac5e81e75a695..89b9d1ce78a52b6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
       return failure();
     if (!toBeHoisted->hasAttr("eligible"))
       return failure();
-    // Hoisting means removing an op from the enclosing op. I.e., the enclosing
-    // op is modified.
-    rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+    rewriter.moveOpBefore(toBeHoisted, op);
     return success();
   }
 };

``````````

</details>


https://github.com/llvm/llvm-project/pull/78988


More information about the Mlir-commits mailing list