[flang-commits] [flang] [mlir] [mlir][IR] Add rewriter API for moving operations (PR #78988)

Matthias Springer via flang-commits flang-commits at lists.llvm.org
Thu Jan 25 01:16:12 PST 2024


https://github.com/matthias-springer updated https://github.com/llvm/llvm-project/pull/78988

>From 6161d337181396db48740d411a7c28c95482d232 Mon Sep 17 00:00:00 2001
From: Matthias Springer <springerm at google.com>
Date: Mon, 22 Jan 2024 15:42:02 +0000
Subject: [PATCH] [mlir][IR] Add rewriter API for moving operations

The pattern rewriter documentation states that "*all* IR mutations [...] are required to be performed via the `PatternRewriter`." This commit adds two functions that were missing from the rewriter API: `moveOpBefore` and `moveOpAfter`.

After an operation was moved, the `notifyOperationInserted` callback is triggered. This may cause listeners such as the greedy pattern rewrite driver to put the op back on the worklist.
---
 .../flang/Optimizer/Builder/FIRBuilder.h      |  6 +++-
 .../HLFIR/Transforms/BufferizeHLFIR.cpp       |  7 +++--
 .../toy/Ch5/mlir/LowerToAffineLoops.cpp       |  4 +--
 .../toy/Ch6/mlir/LowerToAffineLoops.cpp       |  4 +--
 .../toy/Ch7/mlir/LowerToAffineLoops.cpp       |  4 +--
 mlir/include/mlir/IR/Builders.h               | 20 ++++++++-----
 mlir/include/mlir/IR/PatternMatch.h           | 26 +++++++++++++++--
 .../mlir/Transforms/DialectConversion.h       |  8 +++++-
 mlir/lib/Dialect/Affine/IR/AffineOps.cpp      |  4 +--
 .../Async/Transforms/AsyncParallelFor.cpp     |  4 +--
 .../Bufferization/Transforms/Bufferize.cpp    |  6 +++-
 .../TransformOps/LinalgTransformOps.cpp       |  8 ++++--
 .../lib/Dialect/SCF/Transforms/ForToWhile.cpp |  2 +-
 .../BufferizableOpInterfaceImpl.cpp           |  2 +-
 mlir/lib/IR/Builders.cpp                      |  4 +--
 mlir/lib/IR/PatternMatch.cpp                  | 28 +++++++++++++++++++
 .../Transforms/Utils/DialectConversion.cpp    | 18 ++++++++++--
 .../Utils/GreedyPatternRewriteDriver.cpp      | 18 ++++++++----
 .../Utils/LoopInvariantCodeMotionUtils.cpp    |  4 +--
 mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp    |  4 +--
 mlir/test/lib/Dialect/Test/TestPatterns.cpp   |  4 +--
 mlir/test/lib/IR/TestClone.cpp                |  3 +-
 mlir/test/lib/Transforms/TestConstantFold.cpp |  3 +-
 23 files changed, 144 insertions(+), 47 deletions(-)

diff --git a/flang/include/flang/Optimizer/Builder/FIRBuilder.h b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
index b5b2c99810b15bb..c9dcfde56e3829d 100644
--- a/flang/include/flang/Optimizer/Builder/FIRBuilder.h
+++ b/flang/include/flang/Optimizer/Builder/FIRBuilder.h
@@ -490,7 +490,11 @@ class FirOpBuilder : public mlir::OpBuilder, public mlir::OpBuilder::Listener {
   LLVM_DUMP_METHOD void dumpFunc();
 
   /// FirOpBuilder hook for creating new operation.
-  void notifyOperationInserted(mlir::Operation *op) override {
+  void notifyOperationInserted(mlir::Operation *op,
+                               mlir::OpBuilder::InsertPoint previous) override {
+    // We only care about newly created ops; moved ops have `previous` set.
+    if (previous.isSet())
+      return;
     setCommonAttributes(op);
   }
 
diff --git a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
index 641854bd201f0b0..5fe78b7408026f8 100644
--- a/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
+++ b/flang/lib/Optimizer/HLFIR/Transforms/BufferizeHLFIR.cpp
@@ -730,9 +730,10 @@ struct HLFIRListener : public mlir::OpBuilder::Listener {
   HLFIRListener(fir::FirOpBuilder &builder,
                 mlir::ConversionPatternRewriter &rewriter)
       : builder{builder}, rewriter{rewriter} {}
-  void notifyOperationInserted(mlir::Operation *op) override {
-    builder.notifyOperationInserted(op);
-    rewriter.notifyOperationInserted(op);
+  void notifyOperationInserted(mlir::Operation *op,
+                               mlir::OpBuilder::InsertPoint previous) override {
+    builder.notifyOperationInserted(op, previous);
+    rewriter.notifyOperationInserted(op, previous);
   }
   virtual void notifyBlockCreated(mlir::Block *block) override {
     builder.notifyBlockCreated(block);
diff --git a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch5/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch6/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
index ae4bd980c34b53b..948c8a045e341e4 100644
--- a/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
+++ b/mlir/examples/toy/Ch7/mlir/LowerToAffineLoops.cpp
@@ -60,12 +60,12 @@ static Value insertAllocAndDealloc(MemRefType type, Location loc,
 
   // Make sure to allocate at the beginning of the block.
   auto *parentBlock = alloc->getBlock();
-  alloc->moveBefore(&parentBlock->front());
+  rewriter.moveOpBefore(alloc, &parentBlock->front());
 
   // Make sure to deallocate this alloc at the end of the block. This is fine
   // as toy functions have no control flow.
   auto dealloc = rewriter.create<memref::DeallocOp>(loc, alloc);
-  dealloc->moveBefore(&parentBlock->back());
+  rewriter.moveOpBefore(dealloc, &parentBlock->back());
   return alloc;
 }
 
diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h
index 13fbc3fb928c399..6b95be7c6d372f8 100644
--- a/mlir/include/mlir/IR/Builders.h
+++ b/mlir/include/mlir/IR/Builders.h
@@ -205,6 +205,7 @@ class Builder {
 /// automatically inserted at an insertion point. The builder is copyable.
 class OpBuilder : public Builder {
 public:
+  class InsertPoint;
   struct Listener;
 
   /// Create a builder with the given context.
@@ -285,12 +286,17 @@ class OpBuilder : public Builder {
 
     virtual ~Listener() = default;
 
-    /// Notification handler for when an operation is inserted into the builder.
-    /// `op` is the operation that was inserted.
-    virtual void notifyOperationInserted(Operation *op) {}
-
-    /// Notification handler for when a block is created using the builder.
-    /// `block` is the block that was created.
+    /// Notify the listener that the specified operation was inserted.
+    ///
+    /// * If the operation was moved, then `previous` is the previous location
+    ///   of the op.
+    /// * If the operation was unlinked before it was inserted, then `previous`
+    ///   is empty.
+    ///
+    /// Note: Creating an (unlinked) op does not trigger this notification.
+    virtual void notifyOperationInserted(Operation *op, InsertPoint previous) {}
+
+    /// Notify the listener that the specified block was inserted.
     virtual void notifyBlockCreated(Block *block) {}
 
   protected:
@@ -517,7 +523,7 @@ class OpBuilder : public Builder {
     if (succeeded(tryFold(op, results)))
       op->erase();
     else if (listener)
-      listener->notifyOperationInserted(op);
+      listener->notifyOperationInserted(op, /*previous=*/{});
   }
 
   /// Overload to create or fold a single result operation.
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 815340c91850935..7f233cd3f4d4b3c 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -428,6 +428,8 @@ class RewriterBase : public OpBuilder {
 
     /// Notify the listener that the specified operation is about to be erased.
     /// At this point, the operation has zero uses.
+    ///
+    /// Note: This notification is not triggered when unlinking an operation.
     virtual void notifyOperationRemoved(Operation *op) {}
 
     /// Notify the listener that the pattern failed to match the given
@@ -450,8 +452,8 @@ class RewriterBase : public OpBuilder {
   struct ForwardingListener : public RewriterBase::Listener {
     ForwardingListener(OpBuilder::Listener *listener) : listener(listener) {}
 
-    void notifyOperationInserted(Operation *op) override {
-      listener->notifyOperationInserted(op);
+    void notifyOperationInserted(Operation *op, InsertPoint previous) override {
+      listener->notifyOperationInserted(op, previous);
     }
     void notifyBlockCreated(Block *block) override {
       listener->notifyBlockCreated(block);
@@ -591,6 +593,26 @@ class RewriterBase : public OpBuilder {
   /// block into a new block, and return it.
   virtual Block *splitBlock(Block *block, Block::iterator before);
 
+  /// Unlink this operation from its current block and insert it right before
+  /// `existingOp` which may be in the same or another block in the same
+  /// function.
+  void moveOpBefore(Operation *op, Operation *existingOp);
+
+  /// Unlink this operation from its current block and insert it right before
+  /// `iterator` in the specified block.
+  virtual void moveOpBefore(Operation *op, Block *block,
+                            Block::iterator iterator);
+
+  /// Unlink this operation from its current block and insert it right after
+  /// `existingOp` which may be in the same or another block in the same
+  /// function.
+  void moveOpAfter(Operation *op, Operation *existingOp);
+
+  /// Unlink this operation from its current block and insert it right after
+  /// `iterator` in the specified block.
+  virtual void moveOpAfter(Operation *op, Block *block,
+                           Block::iterator iterator);
+
   /// This method is used to notify the rewriter that an in-place operation
   /// modification is about to happen. A call to this function *must* be
   /// followed by a call to either `finalizeOpModification` or
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 9568540789df3f6..32c5937d014e9ef 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -737,7 +737,7 @@ class ConversionPatternRewriter final : public PatternRewriter,
   using PatternRewriter::cloneRegionBefore;
 
   /// PatternRewriter hook for inserting a new operation.
-  void notifyOperationInserted(Operation *op) override;
+  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
 
   /// PatternRewriter hook for updating the given operation in-place.
   /// Note: These methods only track updates to the given operation itself,
@@ -761,9 +761,15 @@ class ConversionPatternRewriter final : public PatternRewriter,
   detail::ConversionPatternRewriterImpl &getImpl();
 
 private:
+  // Hide unsupported pattern rewriter API.
   using OpBuilder::getListener;
   using OpBuilder::setListener;
 
+  void moveOpBefore(Operation *op, Block *block,
+                    Block::iterator iterator) override;
+  void moveOpAfter(Operation *op, Block *block,
+                   Block::iterator iterator) override;
+
   std::unique_ptr<detail::ConversionPatternRewriterImpl> impl;
 };
 
diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
index c260e68d509e983..b802ae33edaccee 100644
--- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
+++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp
@@ -1206,7 +1206,7 @@ mlir::affine::makeComposedFoldedAffineApply(OpBuilder &b, Location loc,
   if (failed(applyOp->fold(constOperands, foldResults)) ||
       foldResults.empty()) {
     if (OpBuilder::Listener *listener = b.getListener())
-      listener->notifyOperationInserted(applyOp);
+      listener->notifyOperationInserted(applyOp, /*previous=*/{});
     return applyOp.getResult();
   }
 
@@ -1274,7 +1274,7 @@ static OpFoldResult makeComposedFoldedMinMax(OpBuilder &b, Location loc,
   if (failed(minMaxOp->fold(constOperands, foldResults)) ||
       foldResults.empty()) {
     if (OpBuilder::Listener *listener = b.getListener())
-      listener->notifyOperationInserted(minMaxOp);
+      listener->notifyOperationInserted(minMaxOp, /*previous=*/{});
     return minMaxOp.getResult();
   }
 
diff --git a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
index 428a3c945581b48..8c3e25355f6087e 100644
--- a/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
+++ b/mlir/lib/Dialect/Async/Transforms/AsyncParallelFor.cpp
@@ -273,7 +273,7 @@ static ParallelComputeFunction createParallelComputeFunction(
   // Insert function into the module symbol table and assign it unique name.
   SymbolTable symbolTable(module);
   symbolTable.insert(func);
-  rewriter.getListener()->notifyOperationInserted(func);
+  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
 
   // Create function entry block.
   Block *block =
@@ -489,7 +489,7 @@ createAsyncDispatchFunction(ParallelComputeFunction &computeFunc,
   // Insert function into the module symbol table and assign it unique name.
   SymbolTable symbolTable(module);
   symbolTable.insert(func);
-  rewriter.getListener()->notifyOperationInserted(func);
+  rewriter.getListener()->notifyOperationInserted(func, /*previous=*/{});
 
   // Create function entry block.
   Block *block = b.createBlock(&func.getBody(), func.begin(), type.getInputs(),
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
index 3f1626a6af34d49..2758d554712b9f4 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp
@@ -371,7 +371,11 @@ class BufferizationRewriter : public IRRewriter, public RewriterBase::Listener {
     toMemrefOps.erase(op);
   }
 
-  void notifyOperationInserted(Operation *op) override {
+  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
+    // We only care about newly created ops.
+    if (previous.isSet())
+      return;
+
     erasedOps.erase(op);
 
     // Gather statistics about allocs.
diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
index 140bdd1f2db3618..803c5691a0403b0 100644
--- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
+++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp
@@ -214,8 +214,12 @@ class NewOpsListener : public RewriterBase::ForwardingListener {
   }
 
 private:
-  void notifyOperationInserted(Operation *op) override {
-    ForwardingListener::notifyOperationInserted(op);
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
+    ForwardingListener::notifyOperationInserted(op, previous);
+    // We only care about newly created ops.
+    if (previous.isSet())
+      return;
     auto inserted = newOps.insert(op);
     (void)inserted;
     assert(inserted.second && "expected newly created op");
diff --git a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
index cda561b1d1054d9..9f8189ae15e6de2 100644
--- a/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/ForToWhile.cpp
@@ -83,7 +83,7 @@ struct ForLoopLoweringPattern : public OpRewritePattern<ForOp> {
 
     // Inline for-loop body operations into 'after' region.
     for (auto &arg : llvm::make_early_inc_range(*forOp.getBody()))
-      arg.moveBefore(afterBlock, afterBlock->end());
+      rewriter.moveOpBefore(&arg, afterBlock, afterBlock->end());
 
     // Add incremented IV to yield operations
     for (auto yieldOp : afterBlock->getOps<scf::YieldOp>()) {
diff --git a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
index 2cd57e7324b4dc5..678b7c099fa3692 100644
--- a/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
+++ b/mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp
@@ -983,7 +983,7 @@ struct ParallelInsertSliceOpInterface
     for (Operation *user : srcBuffer->getUsers()) {
       if (hasEffect<MemoryEffects::Free>(user)) {
         if (user->getBlock() == parallelCombiningParent->getBlock())
-          user->moveBefore(user->getBlock()->getTerminator());
+          rewriter.moveOpBefore(user, user->getBlock()->getTerminator());
         break;
       }
     }
diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp
index d1565047658776d..a319afcdc6a9a23 100644
--- a/mlir/lib/IR/Builders.cpp
+++ b/mlir/lib/IR/Builders.cpp
@@ -412,7 +412,7 @@ Operation *OpBuilder::insert(Operation *op) {
     block->getOperations().insert(insertPoint, op);
 
   if (listener)
-    listener->notifyOperationInserted(op);
+    listener->notifyOperationInserted(op, /*previous=*/{});
   return op;
 }
 
@@ -530,7 +530,7 @@ Operation *OpBuilder::clone(Operation &op, IRMapping &mapper) {
   // about any ops that got inserted inside those regions as part of cloning.
   if (listener) {
     auto walkFn = [&](Operation *walkedOp) {
-      listener->notifyOperationInserted(walkedOp);
+      listener->notifyOperationInserted(walkedOp, /*previous=*/{});
     };
     for (Region &region : newOp->getRegions())
       region.walk<WalkOrder::PreOrder>(walkFn);
diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp
index ba0516e0539b6ca..affb8898fa07544 100644
--- a/mlir/lib/IR/PatternMatch.cpp
+++ b/mlir/lib/IR/PatternMatch.cpp
@@ -366,3 +366,31 @@ void RewriterBase::cloneRegionBefore(Region &region, Region &parent,
 void RewriterBase::cloneRegionBefore(Region &region, Block *before) {
   cloneRegionBefore(region, *before->getParent(), before->getIterator());
 }
+
+void RewriterBase::moveOpBefore(Operation *op, Operation *existingOp) {
+  moveOpBefore(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpBefore(Operation *op, Block *block,
+                                Block::iterator iterator) {
+  // The previous location is the position right after `op`: an iterator to
+  // `op` itself would follow the op to its new location once it is moved.
+  InsertPoint previous(op->getBlock(), std::next(op->getIterator()));
+  op->moveBefore(block, iterator);
+  if (listener)
+    listener->notifyOperationInserted(op, previous);
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Operation *existingOp) {
+  moveOpAfter(op, existingOp->getBlock(), existingOp->getIterator());
+}
+
+void RewriterBase::moveOpAfter(Operation *op, Block *block,
+                               Block::iterator iterator) {
+  // The previous location is the position right after `op`: an iterator to
+  // `op` itself would follow the op to its new location once it is moved.
+  InsertPoint previous(op->getBlock(), std::next(op->getIterator()));
+  op->moveAfter(block, iterator);
+  if (listener)
+    listener->notifyOperationInserted(op, previous);
+}
diff --git a/mlir/lib/Transforms/Utils/DialectConversion.cpp b/mlir/lib/Transforms/Utils/DialectConversion.cpp
index ef6a49455d18605..f5bede2b94f9cb2 100644
--- a/mlir/lib/Transforms/Utils/DialectConversion.cpp
+++ b/mlir/lib/Transforms/Utils/DialectConversion.cpp
@@ -1602,11 +1602,13 @@ void ConversionPatternRewriter::cloneRegionBefore(Region &region,
     Block *cloned = mapping.lookup(&b);
     impl->notifyCreatedBlock(cloned);
     cloned->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>(
-        [&](Operation *op) { notifyOperationInserted(op); });
+        [&](Operation *op) { notifyOperationInserted(op, /*previous=*/{}); });
   }
 }
 
-void ConversionPatternRewriter::notifyOperationInserted(Operation *op) {
+void ConversionPatternRewriter::notifyOperationInserted(Operation *op,
+                                                        InsertPoint previous) {
+  assert(!previous.isSet() && "expected newly created op");
   LLVM_DEBUG({
     impl->logger.startLine()
         << "** Insert  : '" << op->getName() << "'(" << op << ")\n";
@@ -1651,6 +1653,18 @@ LogicalResult ConversionPatternRewriter::notifyMatchFailure(
   return impl->notifyMatchFailure(loc, reasonCallback);
 }
 
+void ConversionPatternRewriter::moveOpBefore(Operation *op, Block *block,
+                                             Block::iterator iterator) {
+  llvm_unreachable(
+      "moving single ops is not supported in a dialect conversion");
+}
+
+void ConversionPatternRewriter::moveOpAfter(Operation *op, Block *block,
+                                            Block::iterator iterator) {
+  llvm_unreachable(
+      "moving single ops is not supported in a dialect conversion");
+}
+
 detail::ConversionPatternRewriterImpl &ConversionPatternRewriter::getImpl() {
   return *impl;
 }
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index ac73e82bfe92a69..c27fee7a738eba0 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -133,9 +133,16 @@ struct ExpensiveChecks : public RewriterBase::ForwardingListener {
     }
   }
 
-  void notifyOperationInserted(Operation *op) override {
-    RewriterBase::ForwardingListener::notifyOperationInserted(op);
+  void notifyOperationInserted(Operation *op, InsertPoint previous) override {
+    RewriterBase::ForwardingListener::notifyOperationInserted(op, previous);
+    // Invalidate the finger print of the op that owns the block into which the
+    // op was inserted.
     invalidateFingerPrint(op->getParentOp());
+
+    // Also invalidate the finger print of the op that owns the block from which
+    // the op was moved. (Only applicable if the op was moved.)
+    if (previous.isSet())
+      invalidateFingerPrint(previous.getBlock()->getParentOp());
   }
 
   void notifyOperationModified(Operation *op) override {
@@ -331,7 +338,7 @@ class GreedyPatternRewriteDriver : public PatternRewriter,
   /// Notify the driver that the specified operation was inserted. Update the
   /// worklist as needed: The operation is enqueued depending on scope and
   /// strict mode.
-  void notifyOperationInserted(Operation *op) override;
+  void notifyOperationInserted(Operation *op, InsertPoint previous) override;
 
   /// Notify the driver that the specified operation was removed. Update the
   /// worklist as needed: The operation and its children are removed from the
@@ -641,13 +648,14 @@ void GreedyPatternRewriteDriver::notifyBlockRemoved(Block *block) {
     config.listener->notifyBlockRemoved(block);
 }
 
-void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op) {
+void GreedyPatternRewriteDriver::notifyOperationInserted(Operation *op,
+                                                         InsertPoint previous) {
   LLVM_DEBUG({
     logger.startLine() << "** Insert  : '" << op->getName() << "'(" << op
                        << ")\n";
   });
   if (config.listener)
-    config.listener->notifyOperationInserted(op);
+    config.listener->notifyOperationInserted(op, previous);
   if (config.strictMode == GreedyRewriteStrictness::ExistingAndNewOps)
     strictModeFilteredOps.insert(op);
   addToWorklist(op);
diff --git a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
index 8f97fd3d9ddf84e..66ce6067963f838 100644
--- a/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
+++ b/mlir/lib/Transforms/Utils/LoopInvariantCodeMotionUtils.cpp
@@ -365,8 +365,8 @@ static LoopLikeOpInterface hoistSubsetAtIterArg(RewriterBase &rewriter,
       iterArg = loopLike.getRegionIterArgs()[iterArgIdx];
       OpResult loopResult = loopLike.getTiedLoopResult(iterArg);
       OpResult newLoopResult = loopLike.getLoopResults()->back();
-      extractionOp->moveBefore(loopLike);
-      insertionOp->moveAfter(loopLike);
+      rewriter.moveOpBefore(extractionOp, loopLike);
+      rewriter.moveOpAfter(insertionOp, loopLike);
       rewriter.replaceAllUsesWith(insertionOp.getUpdatedDestination(),
                                   insertionOp.getDestinationOperand().get());
       extractionOp.getSourceOperand().set(
diff --git a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
index a8a808424b690f6..8a92d840ad13026 100644
--- a/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
+++ b/mlir/test/lib/Dialect/SCF/TestSCFUtils.cpp
@@ -159,8 +159,8 @@ struct TestSCFPipeliningPass
     auto ifOp =
         rewriter.create<scf::IfOp>(loc, op->getResultTypes(), pred, true);
     // True branch.
-    op->moveBefore(&ifOp.getThenRegion().front(),
-                   ifOp.getThenRegion().front().begin());
+    rewriter.moveOpBefore(op, &ifOp.getThenRegion().front(),
+                          ifOp.getThenRegion().front().begin());
     rewriter.setInsertionPointAfter(op);
     if (op->getNumResults() > 0)
       rewriter.create<scf::YieldOp>(loc, op->getResults());
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index d1ac5e81e75a695..89b9d1ce78a52b6 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -193,9 +193,7 @@ struct HoistEligibleOps : public OpRewritePattern<test::OneRegionOp> {
       return failure();
     if (!toBeHoisted->hasAttr("eligible"))
       return failure();
-    // Hoisting means removing an op from the enclosing op. I.e., the enclosing
-    // op is modified.
-    rewriter.modifyOpInPlace(op, [&]() { toBeHoisted->moveBefore(op); });
+    rewriter.moveOpBefore(toBeHoisted, op);
     return success();
   }
 };
diff --git a/mlir/test/lib/IR/TestClone.cpp b/mlir/test/lib/IR/TestClone.cpp
index 13a0cfeb402a9cd..7b18f219b915f46 100644
--- a/mlir/test/lib/IR/TestClone.cpp
+++ b/mlir/test/lib/IR/TestClone.cpp
@@ -15,7 +15,8 @@ using namespace mlir;
 namespace {
 
 struct DumpNotifications : public OpBuilder::Listener {
-  void notifyOperationInserted(Operation *op) override {
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
     llvm::outs() << "notifyOperationInserted: " << op->getName() << "\n";
   }
 };
diff --git a/mlir/test/lib/Transforms/TestConstantFold.cpp b/mlir/test/lib/Transforms/TestConstantFold.cpp
index aa67e0a78d43b7a..81f634d8d3ef78f 100644
--- a/mlir/test/lib/Transforms/TestConstantFold.cpp
+++ b/mlir/test/lib/Transforms/TestConstantFold.cpp
@@ -27,7 +27,8 @@ struct TestConstantFold : public PassWrapper<TestConstantFold, OperationPass<>>,
   void foldOperation(Operation *op, OperationFolder &helper);
   void runOnOperation() override;
 
-  void notifyOperationInserted(Operation *op) override {
+  void notifyOperationInserted(Operation *op,
+                               OpBuilder::InsertPoint previous) override {
     existingConstants.push_back(op);
   }
   void notifyOperationRemoved(Operation *op) override {



More information about the flang-commits mailing list